derektan committed on
Commit 4f09ecf · Parent: 66c5745

Init new app to handle planning. Fresh import of commit 27fe831777c12b25e504dd14e5b661742bdecce6 from VLM-Search.

.gitignore CHANGED
@@ -1 +1,146 @@
1
- **/__pycache__/
1
+ # Additional Stuff
2
+ !train/
3
+ train/*
4
+ !train/saved
5
+
6
+ !model/
7
+ model/*
8
+ !model/saved
9
+
10
+ !inference/
11
+ inference/*
12
+ !inference/saved
13
+
14
+ !gifs/
15
+ gifs/*
16
+ !gifs/saved
17
+
18
+ # Byte-compiled / optimized / DLL files
19
+ __pycache__/
20
+ *.py[cod]
21
+ *$py.class
22
+
23
+ # C extensions
24
+ *.so
25
+
26
+ # Distribution / packaging
27
+ .Python
28
+ build/
29
+ develop-eggs/
30
+ dist/
31
+ downloads/
32
+ eggs/
33
+ .eggs/
34
+ lib/
35
+ lib64/
36
+ parts/
37
+ sdist/
38
+ var/
39
+ wheels/
40
+ pip-wheel-metadata/
41
+ share/python-wheels/
42
+ *.egg-info/
43
+ .installed.cfg
44
+ *.egg
45
+ MANIFEST
46
+
47
+ # PyInstaller
48
+ # Usually these files are written by a python script from a template
49
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
50
+ *.manifest
51
+ *.spec
52
+
53
+ # Installer logs
54
+ pip-log.txt
55
+ pip-delete-this-directory.txt
56
+
57
+ # Unit test / coverage reports
58
+ htmlcov/
59
+ .tox/
60
+ .nox/
61
+ .coverage
62
+ .coverage.*
63
+ .cache
64
+ nosetests.xml
65
+ coverage.xml
66
+ *.cover
67
+ *.py,cover
68
+ .hypothesis/
69
+ .pytest_cache/
70
+
71
+ # Translations
72
+ *.mo
73
+ *.pot
74
+
75
+ # Django stuff:
76
+ *.log
77
+ local_settings.py
78
+ db.sqlite3
79
+ db.sqlite3-journal
80
+
81
+ # Flask stuff:
82
+ instance/
83
+ .webassets-cache
84
+
85
+ # Scrapy stuff:
86
+ .scrapy
87
+
88
+ # Sphinx documentation
89
+ docs/_build/
90
+
91
+ # PyBuilder
92
+ target/
93
+
94
+ # Jupyter Notebook
95
+ .ipynb_checkpoints
96
+
97
+ # IPython
98
+ profile_default/
99
+ ipython_config.py
100
+
101
+ # pyenv
102
+ .python-version
103
+
104
+ # pipenv
105
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
106
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
107
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
108
+ # install all needed dependencies.
109
+ #Pipfile.lock
110
+
111
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
112
+ __pypackages__/
113
+
114
+ # Celery stuff
115
+ celerybeat-schedule
116
+ celerybeat.pid
117
+
118
+ # SageMath parsed files
119
+ *.sage.py
120
+
121
+ # Environments
122
+ .env
123
+ .venv
124
+ env/
125
+ venv/
126
+ ENV/
127
+ env.bak/
128
+ venv.bak/
129
+
130
+ # Spyder project settings
131
+ .spyderproject
132
+ .spyproject
133
+
134
+ # Rope project settings
135
+ .ropeproject
136
+
137
+ # mkdocs documentation
138
+ /site
139
+
140
+ # mypy
141
+ .mypy_cache/
142
+ .dmypy.json
143
+ dmypy.json
144
+
145
+ # Pyre type checker
146
+ .pyre/
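The added `.gitignore` uses the standard whitelist pattern: re-include a directory (`!train/`), ignore everything inside it (`train/*`), then re-include a single subfolder (`!train/saved`), repeated for `model/`, `inference/` and `gifs/`. A minimal sketch of how to verify the effect with `git check-ignore` (the example paths are hypothetical; a printed rule means the path is ignored):

```python
import subprocess

# Under the rules above, files directly under train/ should be ignored,
# while anything inside train/saved/ remains trackable.
for path in ["train/run_001/ckpt.pt", "train/saved/best.ckpt", "gifs/debug.gif"]:
    result = subprocess.run(
        ["git", "check-ignore", "-v", path],
        capture_output=True, text=True,
    )
    rule = result.stdout.strip()
    print(f"{path}: {rule or 'not ignored'}")
```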
Maps/flair_real_maps/envs_val_trained_clss/MSK_000310_great_horned_owl.png ADDED

Git LFS Details

  • SHA256: 7af11bcef1972b7e047f53b597fef2a332d82c7feceb21aac6e14a57469c436b
  • Pointer size: 129 Bytes
  • Size of remote file: 2.34 kB
Maps/flair_real_maps/masks_val_trained_clss/MSK_000310_great_horned_owl.png ADDED

Git LFS Details

  • SHA256: 4e2c4a5be42c630a1ea5f65300e9d9a1eeed746be89befd8e1adde17b80d18e5
  • Pointer size: 129 Bytes
  • Size of remote file: 6.13 kB
Maps/flair_real_maps_lisa_pred/finetuned_LISA_v3_original_losses/flair_lisa_soft_masks_trained_clss_v3/MSK_000310_great_horned_owl.png ADDED

Git LFS Details

  • SHA256: 318773e2c18275d84b5145d7e69836baa0bedd833f44b49f98e6619357677cff
  • Pointer size: 130 Bytes
  • Size of remote file: 75.9 kB
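The three PNGs above are tracked with Git LFS: the repository itself stores only a small text pointer (the "Pointer size" of roughly 129-130 bytes), while the actual image (the "Size of remote file") lives on the LFS server. A minimal sketch of what such a pointer contains and how to read it; the oid is the SHA256 listed above for the first mask, and the byte size is a stand-in since the page only reports ~2.34 kB:

```python
# Git LFS pointer files follow the spec at https://git-lfs.github.com/spec/v1.
# The size value below is hypothetical; the commit page shows only "2.34 kB".
pointer_text = """version https://git-lfs.github.com/spec/v1
oid sha256:7af11bcef1972b7e047f53b597fef2a332d82c7feceb21aac6e14a57469c436b
size 2340
"""

def parse_lfs_pointer(text: str) -> dict:
    # Each line is "<key> <value>"; keep the oid and the remote file size in bytes.
    fields = dict(line.split(" ", 1) for line in text.strip().splitlines())
    return {"oid": fields["oid"], "size_bytes": int(fields["size"])}

print(parse_lfs_pointer(pointer_text))
```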
Taxabind/Taxabind/SatBind/clip_seg_tta.py ADDED
@@ -0,0 +1,933 @@
1
+ from itertools import cycle
2
+ import sys
3
+ import os
4
+
5
+ import cv2
6
+
7
+ # Ensure the correct directory is at the front of sys.path
8
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
9
+ # Remove cached 'model' module to force reloading from the correct path
10
+ if "model" in sys.modules:
11
+ del sys.modules["model"]
12
+
13
+ import json
14
+ import glob
15
+ from datetime import datetime
16
+ import time
17
+
18
+ import torch
19
+ import numpy as np
20
+ from PIL import Image
21
+ from tqdm import tqdm
22
+ from matplotlib import pyplot as plt
23
+ from torch.utils.data import DataLoader
24
+ from torchvision.transforms import v2
25
+
26
+ import open_clip
27
+ from dataset import SatNatDataset
28
+ from model import SatBind
29
+ from transformers import CLIPVisionModelWithProjection
30
+ from kmeans_clustering import CombinedSilhouetteInertiaClusterer
31
+
32
+ # Just to import SoundBind (1 dir back)
33
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
34
+ from SoundBind.model import AudioBind
35
+
36
+ # import matplotlib
37
+ # matplotlib.use("Agg") # <-- key line to avoid tkinter dependency
38
+
39
+ class ClipSegTTA:
40
+ def __init__(
41
+ self,
42
+ img_dir: str,
43
+ imo_dir: str,
44
+ json_path: str,
45
+ sat_to_img_ids_json_path: str,
46
+ patch_size: int,
47
+ sat_checkpoint_path: str,
48
+ sample_index: int = 0,
49
+ blur_kernel = (5,5), # (0,0) for no gaussian blur
50
+ batch_size: int = 1,
51
+ num_workers: int = 1,
52
+ device: str = "cuda",
53
+ sat_to_img_ids_json_is_train_dict: bool = True,
54
+ tax_to_filter_val: str = "",
55
+ load_model: bool = True,
56
+ initial_modality: str = "image",
57
+ sound_data_path: str = None,
58
+ sound_checkpoint_path: str = None,
59
+ ):
60
+ """
61
+ Initialize the ClipSegTTA class with the required parameters.
62
+
63
+ :param img_dir: Path to the ground images directory.
64
+ :param imo_dir: Path to the satellite images directory.
65
+ :param json_path: Path to the iNat JSON file (full dataset).
66
+ :param sample_index: Index of the validation sample loaded first (can be changed later via reset()).
67
+ :param sat_to_img_ids_json_path: Path to the mapping from satellite image IDs to ground image IDs.
68
+ :param patch_size: Size of each satellite patch (e.g. 14).
69
+ :param sat_checkpoint_path: Path to the SatBind checkpoint (CKPT file).
70
+ :param batch_size: Batch size for loading data (default=1).
71
+ :param num_workers: Number of workers for data loading (default=1).
72
+ :param device: Device to use, e.g., 'cuda' or 'cpu' (default='cuda').
73
+ """
74
+
75
+ self.img_dir = img_dir
76
+ self.imo_dir = imo_dir
77
+ self.json_path = json_path
78
+ # self.sat_filtered_json_path = sat_filtered_json_path
79
+ self.sat_to_img_ids_json_path = sat_to_img_ids_json_path
80
+ self.patch_size = patch_size
81
+ self.sat_checkpoint_path = sat_checkpoint_path
82
+ self.sample_index = sample_index # Can be overridden by the reset() function
83
+ self.blur_kernel = blur_kernel
84
+ self.batch_size = batch_size
85
+ self.num_workers = num_workers
86
+ self.device = device
87
+ self.sat_to_img_ids_json_is_train_dict = sat_to_img_ids_json_is_train_dict
88
+ self.tax_to_filter_val = tax_to_filter_val
89
+ self.load_model = load_model
90
+ self.initial_modality = initial_modality
91
+ self.sound_data_path = sound_data_path
92
+ self.sound_checkpoint_path = sound_checkpoint_path
93
+
94
+ # Prepare the dataset
95
+ start_time = time.time()
96
+ self.load_data()
97
+ print(f"Dataset loaded in {(time.time()-start_time):.2f}s.")
98
+
99
+ # Load the global model (original/frozen checkpoint)
100
+ if self.load_model:
101
+ start_time = time.time()
102
+ self.load_global_model()
103
+ self.tokenizer = open_clip.get_tokenizer("hf-hub:imageomics/bioclip")
104
+ print(f"Global model loaded in {(time.time()-start_time):.2f}s.")
105
+
106
+ # Create the local model that will be adapted for TTA
107
+ if self.load_model:
108
+ self.model_local = SatBind(train_dataset=None, val_dataset=None)
109
+ self.model_local.to(self.device)
110
+ self.model_local.eval()
111
+
112
+ # Load sound model if provided
113
+ if self.sound_checkpoint_path:
114
+ self.sound_model = AudioBind.load_from_checkpoint(sound_checkpoint_path, train_dataset=None, val_dataset=None)
115
+ self.sound_model.to(self.device)
116
+ self.sound_model.eval()
117
+
118
+ # Params
119
+ self.reset(sample_idx=self.sample_index)
120
+
121
+ # for idx, related_img_path in enumerate(self.related_imgs_paths):
122
+ # self.visualize_heatmap(
123
+ # step=0,
124
+ # img_path_viz=related_img_path,
125
+ # imo_path_viz=self.imo_path,
126
+ # patch_idx_viz=[self.patch_idx],
127
+ # patch_is_pos=[True],
128
+ # species_name=self.species_name
129
+ # )
130
+
131
+ self.clip_inference_time = 0.0
132
+ self.tta_time = 0.0
133
+
134
+
135
+ def load_data(self):
136
+ """Load or initialize the dataset."""
137
+ self.dataset = SatNatDataset(
138
+ img_dir=self.img_dir,
139
+ imo_dir=self.imo_dir,
140
+ json_path=self.json_path,
141
+ # sat_filtered_json_path=self.sat_filtered_json_path,
142
+ sat_to_img_ids_json_path=self.sat_to_img_ids_json_path,
143
+ sound_data_path=self.sound_data_path,
144
+ patch_size=self.patch_size,
145
+ mode="val",
146
+ get_img_path=True,
147
+ sat_to_img_ids_json_is_train_dict=self.sat_to_img_ids_json_is_train_dict,
148
+ tax_to_filter_val=self.tax_to_filter_val
149
+ )
150
+
151
+ def reset(self, sample_idx):
152
+ """Reset the parameters & local model for the current sample."""
153
+ if self.load_model:
154
+ self.reset_local_model() # Reset to global weights as init
155
+
156
+ self.img_paths, self.imo_path, self.imgs, self.imo, self.sounds, self.sound_ids, self.species_name, self.target_positions, self.gt_mask_name = self.dataset.get_search_ds_data(sample_idx)
157
+ self.imgs = self.imgs.to(self.device)
158
+ # self.img_path = self.img_paths[0] # Select 1st img as query
159
+ # self.img = self.imgs[0] # Select 1st img as query
160
+ # self.img = self.img.to(self.device)
161
+ # self.imo = self.imo.to(self.device)
162
+ self.tgts_gt_score = None
163
+ if self.load_model:
164
+ self.heatmap, self.heatmap_unnormalized, self.heatmap_unnormalized_initial, self.patch_embeds = None, None, None, None
165
+ img = self.imgs[0].unsqueeze(0).to(self.device)
166
+ imo = self.imo.unsqueeze(0).to(self.device)
167
+ sound = self.sounds[0].to(self.device) if self.sounds != [] else None
168
+ txt = [self.species_name]
169
+ self.generate_heatmap(img, imo, txt, sound=sound, modality=self.initial_modality)
170
+
171
+ # Find avg heatmap score for target positions (index target into heatmap)
172
+ scores = []
173
+ imo_orig = Image.open(self.imo_path)
174
+ for pos in self.target_positions:
175
+ row_trans = int(pos[0] * self.heatmap.shape[0] / imo_orig.size[0])
176
+ col_trans = int(pos[1] * self.heatmap.shape[1] / imo_orig.size[1])
177
+ # print("row_trans, col_trans: ", row_trans, col_trans)
178
+ scores.append(self.heatmap[row_trans, col_trans])
179
+ self.tgts_gt_score = np.mean(scores)
180
+ # print("scores: ", scores)
181
+ # print("self.tgts_gt_score: ", self.tgts_gt_score)
182
+
183
+
184
+ # print("Sample reset for TTA: ", self.species_name)
185
+
186
+ def load_global_model(self):
187
+ """Load the global SatBind model from checkpoint, move to device, and eval."""
188
+ self.model_global = SatBind.load_from_checkpoint(
189
+ self.sat_checkpoint_path, train_dataset=None, val_dataset=None
190
+ )
191
+ self.model_global = self.model_global.to(self.device)
192
+ self.model_global.eval()
193
+
194
+ def reset_local_model(self):
195
+ """
196
+ Reset the local model to match the global model's parameters
197
+ and freeze/unfreeze layers for TTA.
198
+ """
199
+ start_time = time.time()
200
+ with torch.no_grad():
201
+ for param_global, param_local in zip(
202
+ self.model_global.parameters(), self.model_local.parameters()
203
+ ):
204
+ param_local.data.copy_(param_global.data)
205
+
206
+ # Freeze everything except the satellite encoder & custom projection
207
+ for name, param in self.model_local.named_parameters():
208
+ if "imo_encoder" in name or "visual_projection_custom" in name:
209
+ param.requires_grad = True
210
+ else:
211
+ param.requires_grad = False
212
+
213
+ self.model_local.eval()
214
+ # print(f"Local model reset in {(time.time()-start_time):.2f}s.")
215
+
216
+
217
+ def execute_tta(
218
+ self,
219
+ patch_indices: list,
220
+ patch_is_pos: list,
221
+ pos_sample_weight: list,
222
+ neg_sample_weight: list,
223
+ tta_steps: int = 10,
224
+ lr: float = 2e-6,
225
+ modality: str = "image", # image, text, or combined
226
+ query_variety: bool = False, # Whether to use different query images
227
+ target_found_idxs: list = [],
228
+ reset_weights: bool = True,
229
+ num_viz_steps: int = 1,
230
+ viz_heatmap: bool = False,
231
+ ):
232
+ """
233
+ Run test-time adaptation using the local model. The local model is first
234
+ reset to the global weights. After TTA, the global model remains
235
+ unchanged; only the local model is updated.
236
+
237
+ :param patch_indices: Indices of observed satellite patches (the [CLS] offset is added internally).
238
+ :param tta_steps: Number of test-time adaptation steps.
239
+ :param num_viz_steps: Visualize heatmap after every 'num_viz_steps' steps.
240
+ :param viz_heatmap: If True, perform visualization. If False, skip plotting.
241
+ """
242
+
243
+ ### Option 1: SAMPLE FROM DATASET
244
+ # 1) Reset the local model to global weights
245
+ if reset_weights:
246
+ self.reset_local_model()
247
+
248
+ # 2) Prepare the sample(s) for TTA
249
+ # print("target_found_idxs: ", target_found_idxs)
250
+ if query_variety:
251
+ indices = torch.tensor(target_found_idxs).to(dtype=torch.long).to(self.device)
252
+ img = self.imgs[indices].to(self.device)
253
+ # print("~variety")
254
+ elif (self.initial_modality == "text" or self.initial_modality == "sound") and modality == "combined" and len(target_found_idxs) == 0:
255
+ indices = torch.tensor(target_found_idxs).to(dtype=torch.long).to(self.device)
256
+ img = self.imgs[indices].to(self.device)
257
+ # print("~empty")
258
+ else:
259
+ img = self.imgs[0].unsqueeze(0).to(self.device)
260
+ # print("~single")
261
+ imo = self.imo.unsqueeze(0).to(self.device) # vectorize in shared_step to make imo uniform (faster)
262
+ txt = [self.species_name]
263
+ sound = self.sounds[0].to(self.device) if self.sounds != [] else None
264
+ patch_indices = [idx+1 for idx in patch_indices] # BUGFIX: Consider the [CLS] token offset
265
+ patch_idx = torch.tensor(patch_indices).to(self.device)
266
+
267
+ # print("img.shape: ", img.shape)
268
+ # print("imo.shape: ", imo.shape)
269
+ # print("patch_idx.shape: ", patch_idx.shape)
270
+ # print("species_txt: ", txt)
271
+ # ---------------------------------------------------------------------
272
+
273
+
274
+ # 5) Set up optimizer (only for sat branch) & AMP scaler
275
+ optimizer = torch.optim.Adam(
276
+ [p for p in self.model_local.parameters() if p.requires_grad], lr=lr
277
+ )
278
+ # scaler = torch.cuda.amp.GradScaler(init_scale=1000)
279
+
280
+ # print(
281
+ # f"Starting TTA (patch_idx={patch_idx[0].item()}) for {tta_steps} steps. "
282
+ # f"Visualization: {'ON' if viz_heatmap else 'OFF'} every {num_viz_steps} step(s)."
283
+ # )
284
+
285
+ start_time = time.time()
286
+
287
+ # 6) TTA loop
288
+ for step in range(tta_steps):
289
+ # # with torch.cuda.amp.autocast():
290
+ batch_size = imo.shape[0]
291
+
292
+ # Query embeds
293
+ query_embeds = self.generate_query_embeds(img, imo, txt, sound=sound, modality=modality)
294
+
295
+ # Sat Embeds
296
+ imo_embeds = self.model_local.imo_encoder(imo).last_hidden_state # (batch, Patches, hidden_dim)
297
+ imo_embeds = imo_embeds[torch.arange(batch_size), patch_idx] # (batch, hidden_dim)
298
+ imo_embeds = self.model_local.visual_projection_custom(imo_embeds) # (batch_size, proj_dim)
299
+ imo_embeds = torch.nn.functional.normalize(imo_embeds, dim=-1)
300
+
301
+ # Compute Similarity Loss
302
+ logit_scale = self.model_local.logit_scale.exp()
303
+ similarity = imo_embeds @ query_embeds.t() * logit_scale
304
+ # target = torch.tensor(patch_is_pos, dtype=torch.float32, device=similarity.device)
305
+ # per_sample_weight = torch.tensor(per_sample_weight, dtype=torch.float32, device=similarity.device)
306
+ # criterion = torch.nn.BCEWithLogitsLoss(weight=per_sample_weight) # Includes sigmoid activation internally
307
+ # loss = criterion(similarity.squeeze(), target)
308
+ # # NOTE: Equivalent loss scaling
309
+ # # criterion = torch.nn.BCEWithLogitsLoss(reduction="none")
310
+ # # loss = criterion(similarity.squeeze(), target)
311
+ # # loss = loss * per_sample_weight
312
+ # # loss = loss.mean()
313
+
314
+ # loss = 0
315
+ # patch_probs = similarity.squeeze()
316
+ # for probs, counts in zip(patch_probs, patch_is_pos):
317
+ # print("patch_probs.shape: ", patch_probs.shape)
318
+ # delta_loss = (probs - counts * torch.log(probs + 1e-6))
319
+ # print("delta_loss.shape: ", delta_loss.shape)
320
+ # loss += delta_loss
321
+
322
+ # Negative log-likelihood loss for a spatial Poisson point process
323
+ patch_probs = similarity.squeeze().sigmoid()
324
+ counts = torch.tensor(patch_is_pos, dtype=torch.float32, device=similarity.device)
325
+ pos_weights = torch.tensor(pos_sample_weight, dtype=torch.float32, device=similarity.device)
326
+ neg_weights = torch.tensor(neg_sample_weight, dtype=torch.float32, device=similarity.device)
327
+ # loss = (neg_weight * patch_probs - pos_weight * counts * torch.log(patch_probs + 1e-6))
328
+ loss = (neg_weights * patch_probs - pos_weights * counts * torch.log(patch_probs + 1e-6))
329
+ # print("patch_probs: ", patch_probs)
330
+ # print("counts: ", counts)
331
+ # print("pos_weights: ", pos_weights)
332
+ # print("neg_weights: ", neg_weights)
333
+ # print("loss: ", loss)
334
+ loss = loss.sum()
335
+
336
+ # # Print loss info
337
+ # print(f"Step {step:03d}: CLIP Loss = {loss.item():.4f}")
338
+
339
+ # # Backprop and update
340
+ # NOTE: No need for scaler since not using mixed precision (i.e. autocast)
341
+ # optimizer.zero_grad()
342
+ # scaler.scale(loss).backward()
343
+ # scaler.step(optimizer)
344
+ # scaler.update()
345
+ optimizer.zero_grad()
346
+ loss.backward()
347
+ optimizer.step()
348
+
349
+ self.tta_time = time.time() - start_time
350
+ # print("self.tta_time: ", self.tta_time)
351
+
352
+ # Visualization every 'num_viz_steps' steps (if enabled)
353
+ if (step + 1) % num_viz_steps == 0 and viz_heatmap:
354
+ # Visualize only the first sample in the batch
355
+ # self.generate_heatmap(img[0], imo[0])
356
+ self.generate_heatmap(img, imo, txt, sound=sound, modality=modality)
357
+ self.visualize_heatmap(
358
+ step=step,
359
+ img_path_viz=self.img_paths[0], # Viz 1st image
360
+ imo_path_viz=self.imo_path,
361
+ patch_idx_viz=patch_idx,
362
+ patch_is_pos=patch_is_pos,
363
+ species_name=self.species_name
364
+ )
365
+
366
+ # Save final heatmap after TTA steps
367
+ self.generate_heatmap(img, imo, txt, sound=sound, modality=modality)
368
+
369
+
370
+ def generate_query_embeds(self, img, imo, txt, sound=None, modality="image"):
371
+
372
+ # Query Embeds
373
+ if modality == "image":
374
+ # print("~Image modality")
375
+ query_embeds, *_ = self.model_local.bio_model(img) # (batch_size, proj_dim)
376
+ if query_embeds.shape[0] > 1:
377
+ query_embeds = query_embeds.mean(dim=0, keepdim=True) # (1, proj_dim)
378
+ elif modality == "text" or (modality == "combined" and img.shape[0] == 0):
379
+ # print("~Text modality")
380
+ txt_tokenized = self.tokenizer(txt).to(imo.device)
381
+ _, query_embeds, _ = self.model_local.bio_model(text=txt_tokenized)
382
+ elif modality == "sound" or (modality == "combined" and img.shape[0] == 0):
383
+ # print("~Sound modality")
384
+ if sound is None:
385
+ print("!!!! Sound modality requires sound input !!!")
386
+ exit(1)
387
+ unnormalized_audio_embeds = self.sound_model.audio_encoder(sound)
388
+ query_embeds = torch.nn.functional.normalize(unnormalized_audio_embeds, dim=-1)
389
+ elif modality == "combined":
390
+ # print("~Combined modality")
391
+ img_embeds, *_ = self.model_local.bio_model(img)
392
+ if img_embeds.shape[0] > 1:
393
+ img_embeds = img_embeds.mean(dim=0, keepdim=True) # (1, proj_dim)
394
+ txt_tokenized = self.tokenizer(txt).to(imo.device)
395
+ _, txt_embeds, _ = self.model_local.bio_model(text=txt_tokenized)
396
+ query_embeds = (img_embeds + txt_embeds) / 2
397
+ else:
398
+ raise ValueError("Invalid modality")
399
+
400
+ return query_embeds
401
+
402
+
403
+ def generate_heatmap(self, img, imo, txt, sound=None, modality="image"):
404
+
405
+ start_time = time.time()
406
+
407
+ # Satellite encoder outputs
408
+ imo_embeds = self.model_local.imo_encoder(imo).last_hidden_state
409
+ # Apply custom projection
410
+ imo_embeds = self.model_local.visual_projection_custom(imo_embeds)
411
+ imo_embeds = torch.nn.functional.normalize(imo_embeds, dim=-1)
412
+ # Remove batch dimension -> (num_tokens, proj_dim)
413
+ imo_embeds = imo_embeds.squeeze(0)
414
+ self.patch_embeds = imo_embeds.clone()[1:].cpu().detach().numpy()
415
+
416
+ # # Ground image embedding (bio CLIP model)
417
+ # img_embeds, *_ = self.model_local.bio_model(img)
418
+ query_embeds = self.generate_query_embeds(img, imo, txt, sound=sound, modality=modality)
419
+
420
+ # Same logit scale as in SatBind
421
+ logit_scale = self.model_local.logit_scale.exp()
422
+ sim = query_embeds @ imo_embeds.t() * logit_scale
423
+ # Sigmoid to get similarity scores
424
+ scores = sim.t().sigmoid() # (num_tokens, 1)
425
+
426
+ # Exclude [CLS] token at index 0
427
+ score_no_cls = scores[1:].squeeze() # shape: (num_tokens-1,)
428
+ num_tokens = score_no_cls.shape[0]
429
+ side_dim = int(num_tokens**0.5)
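+ # e.g. for the 336px CLIP ViT-L/14 satellite encoder: 577 tokens = 1 [CLS] + 576 patches, so side_dim = 24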
430
+ sim_scores = score_no_cls.reshape(side_dim, side_dim).clone()
431
+ sim_scores = sim_scores.cpu().detach().numpy()
432
+
433
+ self.clip_inference_time = time.time() - start_time
434
+ # print("self.clip_inference_time: ", self.clip_inference_time)
435
+
436
+ # Gaussian smoothing
437
+ if self.blur_kernel != (0,0):
438
+ sim_scores = cv2.GaussianBlur(sim_scores, self.blur_kernel, 0)
439
+
440
+ # Normalize to expectation
441
+ self.heatmap_unnormalized = sim_scores
442
+ scale = len(self.target_positions) / (self.heatmap_unnormalized.sum())
443
+ self.heatmap_unnormalized *= scale
444
+ if self.heatmap_unnormalized_initial is None:
445
+ self.heatmap_unnormalized_initial = self.heatmap_unnormalized.copy()
446
+ # print("self.heatmap_unnormalized.sum(): ", self.heatmap_unnormalized.sum())
447
+
448
+ # Standard normalization to (0,1)
449
+ # self.heatmap = score_no_cls.reshape(side_dim, side_dim).clone()
450
+ # self.heatmap = self.heatmap.cpu().detach().numpy()
451
+ self.heatmap = sim_scores.copy()
452
+ self.heatmap = (self.heatmap - self.heatmap.min()) / (self.heatmap.max() - self.heatmap.min()) # normalize heatmap
453
+
454
+
455
+ def visualize_heatmap(
456
+ self,
457
+ step: int,
458
+ img_path_viz: str,
459
+ imo_path_viz: str,
460
+ patch_idx_viz: torch.Tensor,
461
+ patch_is_pos: list,
462
+ species_name: str
463
+ ):
464
+ """
465
+ Visualization function that plots the ground image, satellite image with
466
+ highlighted patch, and the learned heatmap.
467
+
468
+ :param step: Current TTA step (for labeling the plots).
469
+ :param img_path_viz: File path to the ground image.
470
+ :param imo_path_viz: File path to the satellite image.
471
+ :param patch_idx_viz: The patch index (tensor) being highlighted.
472
+ :param patch_is_pos: Per-patch flags; True marks a positive (target) patch.
+ :param species_name: Species name used as the query label in the figure title.
476
+ """
477
+
478
+ # Switch off gradients for visualization
479
+ with torch.no_grad():
480
+ side_dim = self.heatmap.shape[0]
481
+
482
+ # -----------------------------------------------------------------
483
+ # Highlight the patch in the satellite image
484
+ sat_img_orig = Image.open(imo_path_viz)
485
+ sat_highlight = np.array(
486
+ self.dataset.debug_imo_viz_transform(sat_img_orig.copy())
487
+ )
488
+
489
+ for idx, patch_idx in enumerate(patch_idx_viz):
490
+
491
+ # Because patch_idx includes the [CLS] offset, subtract 1
492
+ patch_idx_actual = patch_idx - 1
493
+
494
+ # Get dimensions (H x W)
495
+ H, W = sat_highlight.shape[0], sat_highlight.shape[1]
496
+
497
+ # Number of patches in each dimension
498
+ patches_per_col = W // self.patch_size
499
+ patches_per_row = H // self.patch_size
500
+
501
+ # Determine row/col in the patch grid
502
+ patch_row = patch_idx_actual // patches_per_col
503
+ patch_col = patch_idx_actual % patches_per_col # modulo by the number of patch columns (equal to patches_per_row for square inputs)
504
+
505
+ # Pixel boundaries
506
+ x_start = patch_col * self.patch_size
507
+ x_end = (patch_col + 1) * self.patch_size
508
+ y_start = patch_row * self.patch_size
509
+ y_end = (patch_row + 1) * self.patch_size
510
+
511
+ # Fill patch area with a grey color
512
+ # if sat_highlight.dtype == np.uint8:
513
+ # grey_value = 128
514
+ # else:
515
+ # grey_value = 0.5
516
+ # sat_highlight[y_start:y_end, x_start:x_end, :] = grey_value
517
+ # Blue fill for positive patches
518
+ if patch_is_pos[idx]:
519
+ sat_highlight[y_start:y_end, x_start:x_end, 0] = 0
520
+ sat_highlight[y_start:y_end, x_start:x_end, 1] = 0
521
+ sat_highlight[y_start:y_end, x_start:x_end, 2] = 255
522
+ # Red fill for negative patches
523
+ else:
524
+ sat_highlight[y_start:y_end, x_start:x_end, 0] = 255
525
+ sat_highlight[y_start:y_end, x_start:x_end, 1] = 0
526
+ sat_highlight[y_start:y_end, x_start:x_end, 2] = 0
527
+
528
+
529
+ # -----------------------------------------------------------------
530
+ # Plot results
531
+ fig, axes = plt.subplots(1, 3, figsize=(12, 6))
532
+ fig.suptitle(f"Query: {species_name}")
533
+
534
+ # Ground image
535
+ img_orig = Image.open(img_path_viz)
536
+ axes[0].imshow(img_orig)
537
+ axes[0].set_title("Ground Image")
538
+ axes[0].axis("off")
539
+
540
+ # Satellite image
541
+ axes[1].imshow(sat_highlight)
542
+ axes[1].set_title("Sat Image")
543
+ axes[1].axis("off")
544
+
545
+ # Heatmap
546
+ heatmap_np = self.heatmap_unnormalized
547
+ im = axes[2].imshow(heatmap_np, cmap="viridis")
548
+ axes[2].set_title(
549
+ f"Heatmap at TTA Step {step:03d} ({side_dim}x{side_dim})"
550
+ )
551
+ axes[2].axis("off")
552
+ fig.colorbar(im, ax=axes[2], fraction=0.046, pad=0.04)
553
+
554
+ plt.tight_layout()
555
+ plt.show()
556
+
557
+
558
+ if __name__ == "__main__":
559
+
560
+ # # (CHANGE ME!) PARAMS: VAL
561
+ # clip_seg_tta = ClipSegTTA(
562
+ # img_dir="/mnt/hdd/inat2021_ds/inat21",
563
+ # imo_dir="/mnt/hdd/inat2021_ds/sat_test_jpg_512px",
564
+ # json_path="/mnt/hdd/inat2021_ds/inat21_val.json",
565
+ # sat_filtered_json_path="/mnt/hdd/inat2021_ds/2_OTHERS/inat21_filtered_pixel_clip_val.json",
566
+ # sat_to_img_ids_json_path="/mnt/hdd/inat2021_ds/2_OTHERS/filtered_mapping_sat_to_img_ids_val.json",
567
+ # patch_size=14,
568
+ # sat_checkpoint_path="/home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px/pixel_clip_512px_190225_CLIP-L-336/satbind-epoch=02-val_loss=2.40.ckpt", # /home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px/pixel_clip_512px_160225_NO_DATASET_SHUFFLE/satbind-epoch=02-val_loss=2.26-BACKUP.ckpt
569
+ # sample_index = 45,
570
+ # batch_size=1,
571
+ # num_workers=1,
572
+ # device="cuda",
573
+ # )
574
+
575
+ # Run TTA on sample index 45, for 10 steps,
576
+ # visualizing every 2 steps, with heatmap enabled:
577
+
578
+ # # Step 1
579
+ # print("Executing step 1...")
580
+ # patch_indices = [422, 204, 45]
581
+ # patch_is_pos = [True, False, False]
582
+ # clip_seg_tta.execute_tta(patch_indices, patch_is_pos, tta_steps=10, neg_sample_weight_scale=1.0, num_viz_steps=2, viz_heatmap=True)
583
+
584
+ # # Step 2
585
+ # print("Executing step 2...")
586
+ # # patch_indices = [185, 89, 30] # for 256px
587
+ # patch_indices = [422, 204, 45, 423] # for 512px
588
+ # patch_is_pos = [True, False, False, True]
589
+ # clip_seg_tta.execute_tta(patch_indices, patch_is_pos, tta_steps=10, neg_sample_weight_scale=1.0, num_viz_steps=2, viz_heatmap=True)
590
+
591
+
592
+ # # (CHANGE ME!) PARAMS: TRAIN w/ TARGETS
593
+ # clip_seg_tta = ClipSegTTA(
594
+ # img_dir="/mnt/hdd/inat2021_ds/inat21",
595
+ # imo_dir="/mnt/hdd/inat2021_ds/sat_train_jpg_512px",
596
+ # json_path="/mnt/hdd/inat2021_ds/inat21_train.json",
597
+ # sat_filtered_json_path="/mnt/hdd/inat2021_ds/2_OTHERS/inat21_filtered_pixel_clip_train.json",
598
+ # # sat_to_img_ids_json_path="/mnt/hdd/inat2021_ds/2_OTHERS/filtered_mapping_sat_to_img_ids_train.json",
599
+ # sat_to_img_ids_json_path="/mnt/hdd/inat2021_ds/target_search_ds/OLD/taxon_sat_target_search_100x_per_10-20counts.json",
600
+ # patch_size=14,
601
+ # sat_checkpoint_path="/home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px/pixel_clip_512px_190225_CLIP-L-336/satbind-epoch=02-val_loss=2.40.ckpt", # /home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px/pixel_clip_512px_160225_NO_DATASET_SHUFFLE/satbind-epoch=02-val_loss=2.26-BACKUP.ckpt
602
+ # sample_index = 99,
603
+ # batch_size=1,
604
+ # num_workers=1,
605
+ # device="cuda",
606
+ # )
607
+
608
+ # # (CHANGE ME!) PARAMS: TRAIN w/ TARGETS
609
+ # clip_seg_tta = ClipSegTTA(
610
+ # img_dir="/mnt/hdd/inat2021_ds/inat21",
611
+ # imo_dir="/mnt/hdd/inat2021_ds/sat_train_jpg_512px",
612
+ # json_path="/mnt/hdd/inat2021_ds/inat21_train.json",
613
+ # # sat_filtered_json_path="/mnt/hdd/inat2021_ds/2_OTHERS/inat21_filtered_pixel_clip_train.json",
614
+ # sat_to_img_ids_json_path="/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/FINAL_COUNTS_COMBINED_top4000_pos_top55000_neg/search_val_in.json",
615
+ # sound_data_path='/mnt/hdd/inat2021_ds/2_OTHERS/sound_train',
616
+ # patch_size=14,
617
+ # sat_checkpoint_path="/home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px_search_ds_filtered/pixel_clip_512px_search_ds_180325_CLIP-L-336/satbind-epoch=02-val_loss=2.46-BACKUP.ckpt", # /home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px/pixel_clip_512px_190225_CLIP-L-336/satbind-epoch=02-val_loss=2.40.ckpt
618
+ # sound_checkpoint_path = "/home/user/Taxabind/TaxaBind/SoundBind/checkpoints/BUGFIX_CLIP_TRAIN_CORRECT_NOT_VAL_OUT_TAX_FILTER_TGT_ONLY_v2_130625/soundbind-epoch=19-val_loss=3.96_BACKUP.ckpt",
619
+ # sample_index = 8, # 5,6,8
620
+ # batch_size=1,
621
+ # num_workers=1,
622
+ # device="cuda",
623
+ # sat_to_img_ids_json_is_train_dict=False,
624
+ # )
625
+
626
+ # (CHANGE ME!) PARAMS: TRAIN w/ TARGETS
627
+ clip_seg_tta = ClipSegTTA(
628
+ img_dir="/mnt/hdd/inat2021_ds/inat21",
629
+ imo_dir="/mnt/hdd/inat2021_ds/sat_train_jpg_512px",
630
+ json_path="/mnt/hdd/inat2021_ds/inat21_train.json",
631
+ # sat_filtered_json_path="/mnt/hdd/inat2021_ds/2_OTHERS/inat21_filtered_pixel_clip_train.json",
632
+ sat_to_img_ids_json_path="/mnt/hdd/inat2021_ds/2_OTHERS/sound_train/sound_image_pairs_filtered_v3_with_in_domain_taxs_150625/val_in_with_sound_ids_v3.json",
633
+ sound_data_path='/mnt/hdd/inat2021_ds/2_OTHERS/sound_train',
634
+ patch_size=14,
635
+ sat_checkpoint_path="/home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px_search_ds_filtered/pixel_clip_512px_search_ds_180325_CLIP-L-336/satbind-epoch=02-val_loss=2.46-BACKUP.ckpt", # /home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px/pixel_clip_512px_190225_CLIP-L-336/satbind-epoch=02-val_loss=2.40.ckpt
636
+ sound_checkpoint_path = "/home/user/Taxabind/TaxaBind/SoundBind/checkpoints/BUGFIX_CLIP_TRAIN_CORRECT_NOT_VAL_OUT_TAX_FILTER_TGT_ONLY_v2_130625/soundbind-epoch=19-val_loss=3.96_BACKUP.ckpt",
637
+ sample_index = 8, # 5,6,8
638
+ batch_size=1,
639
+ num_workers=1,
640
+ device="cuda",
641
+ sat_to_img_ids_json_is_train_dict=False,
642
+ initial_modality="sound"
643
+ )
644
+
645
+
646
+ ###########################################
647
+ # Scaling with list
648
+ ###########################################
649
+
650
+ # # print("Executing step 4b...")
651
+ # patch_indices = [422, 45]
652
+ # patch_is_pos = [True, False]
653
+ # # per_sample_weight = [1.0, 1.0]
654
+ # pos_sample_weight = 1.0
655
+ # neg_sample_weight = 1.0
656
+ # clip_seg_tta.execute_tta(
657
+ # patch_indices,
658
+ # patch_is_pos,
659
+ # pos_sample_weight,
660
+ # neg_sample_weight,
661
+ # tta_steps=10,
662
+ # num_viz_steps=2,
663
+ # viz_heatmap=True)
664
+
665
+ # print("Executing step 4b...")
666
+ patch_indices = [422, 32]
667
+ patch_is_pos = [True, False]
668
+ # per_sample_weight = [1.0, 1.0]
669
+ pos_sample_weight = 1.0
670
+ neg_sample_weight = 1.0
671
+ clip_seg_tta.execute_tta(
672
+ patch_indices,
673
+ patch_is_pos,
674
+ pos_sample_weight,
675
+ neg_sample_weight,
676
+ tta_steps=30,
677
+ num_viz_steps=2,
678
+ viz_heatmap=True,
679
+ modality="sound")
680
+
681
+ # # print("Executing step 4b...")
682
+ # patch_indices = [45, 422]
683
+ # patch_is_pos = [True, False]
684
+ # per_sample_weight = [1.0, 1.0]
685
+ # clip_seg_tta.execute_tta(
686
+ # patch_indices,
687
+ # patch_is_pos,
688
+ # tta_steps=20,
689
+ # per_sample_weight=per_sample_weight,
690
+ # num_viz_steps=2,
691
+ # viz_heatmap=True)
692
+
693
+ # # print("Executing step 4b...")
694
+ # patch_indices = [45, 422]
695
+ # patch_is_pos = [2, False]
696
+ # per_sample_weight = [0.1, 1.0]
697
+ # clip_seg_tta.execute_tta(
698
+ # patch_indices,
699
+ # patch_is_pos,
700
+ # tta_steps=20,
701
+ # per_sample_weight=per_sample_weight,
702
+ # num_viz_steps=2,
703
+ # viz_heatmap=True)
704
+
705
+ # # print("Executing step 4b...")
706
+ # patch_indices = [45, 422]
707
+ # patch_is_pos = [True, False]
708
+ # per_sample_weight = [1.0, 0.1]
709
+ # clip_seg_tta.execute_tta(
710
+ # patch_indices,
711
+ # patch_is_pos,
712
+ # tta_steps=10,
713
+ # per_sample_weight=per_sample_weight,
714
+ # num_viz_steps=2,
715
+ # viz_heatmap=True)
716
+
717
+ # # print("Executing step 4b...")
718
+ # patch_indices = [422, 424, 398, 400, 45, 47, 41, 43]
719
+ # patch_is_pos = [True, True, True, True, False, False, False, False]
720
+ # per_sample_weight = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
721
+ # clip_seg_tta.execute_tta(
722
+ # patch_indices,
723
+ # patch_is_pos,
724
+ # tta_steps=10,
725
+ # per_sample_weight=per_sample_weight,
726
+ # num_viz_steps=2,
727
+ # viz_heatmap=True)
728
+
729
+ # # print("Executing step 4b...")
730
+ # patch_indices = [45, 47, 41, 43, 422, 424, 398, 400]
731
+ # patch_is_pos = [True, True, True, True, False, False, False, False]
732
+ # per_sample_weight = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
733
+ # clip_seg_tta.execute_tta(
734
+ # patch_indices,
735
+ # patch_is_pos,
736
+ # tta_steps=10,
737
+ # per_sample_weight=per_sample_weight,
738
+ # num_viz_steps=2,
739
+ # viz_heatmap=True)
740
+
741
+ # # print("Executing step 4b...")
742
+ # patch_indices = [45, 47, 41, 43, 422, 424, 398, 400]
743
+ # patch_is_pos = [True, True, True, True, False, False, False, False]
744
+ # per_sample_weight = [1.0, 1.0, 1.0, 1.0, 0.1, 0.1, 0.1, 0.1]
745
+ # clip_seg_tta.execute_tta(
746
+ # patch_indices,
747
+ # patch_is_pos,
748
+ # tta_steps=10,
749
+ # per_sample_weight=per_sample_weight,
750
+ # num_viz_steps=2,
751
+ # viz_heatmap=True)
752
+
753
+ # # print("Executing step 4b...")
754
+ # patch_indices = [45, 47, 41, 43, 422, 424, 398, 400]
755
+ # patch_is_pos = [True, True, True, True, False, False, False, False]
756
+ # per_sample_weight = [0.1, 0.1, 0.1, 0.1, 1.0, 1.0, 1.0, 1.0]
757
+ # clip_seg_tta.execute_tta(
758
+ # patch_indices,
759
+ # patch_is_pos,
760
+ # tta_steps=10,
761
+ # per_sample_weight=per_sample_weight,
762
+ # num_viz_steps=2,
763
+ # viz_heatmap=True)
764
+
765
+ # # print("Executing step 4b...")
766
+ # patch_indices = [45, 47, 41, 43, 42, 44, 46, 40]
767
+ # patch_is_pos = [True, True, True, True, True, True, True, True]
768
+ # per_sample_weight = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
769
+ # clip_seg_tta.execute_tta(
770
+ # patch_indices,
771
+ # patch_is_pos,
772
+ # tta_steps=10,
773
+ # per_sample_weight=per_sample_weight,
774
+ # num_viz_steps=2,
775
+ # viz_heatmap=True)
776
+
777
+ # # print("Executing step 4b...")
778
+ # patch_indices = [45, 47, 41, 43, 422, 424, 398, 400]
779
+ # patch_is_pos = [True, True, True, True, False, False, False, False]
780
+ # per_sample_weight = [1.0, 1.0, 1.0, 1.0, 0.1, 0.1, 0.1, 0.1]
781
+ # clip_seg_tta.execute_tta(
782
+ # patch_indices,
783
+ # patch_is_pos,
784
+ # tta_steps=10,
785
+ # per_sample_weight=per_sample_weight,
786
+ # num_viz_steps=2,
787
+ # viz_heatmap=True)
788
+
789
+ # # print("Executing step 4b...")
790
+ # patch_indices = [554, 45, 422]
791
+ # patch_is_pos = [True, True, False]
792
+ # per_sample_weight = [1.0, 1.0, 1.0]
793
+ # clip_seg_tta.execute_tta(
794
+ # patch_indices,
795
+ # patch_is_pos,
796
+ # tta_steps=10,
797
+ # per_sample_weight=per_sample_weight,
798
+ # num_viz_steps=2,
799
+ # viz_heatmap=True)
800
+
801
+ # # print("Executing step 4b...")
802
+ # patch_indices = [554, 45, 422]
803
+ # patch_is_pos = [True, True, False]
804
+ # per_sample_weight = [1.0, 1.0, 0.1]
805
+ # clip_seg_tta.execute_tta(
806
+ # patch_indices,
807
+ # patch_is_pos,
808
+ # tta_steps=10,
809
+ # per_sample_weight=per_sample_weight,
810
+ # num_viz_steps=2,
811
+ # viz_heatmap=True)
812
+
813
+ # # print("Executing step 4b...")
814
+ # patch_indices = [554, 45, 422]
815
+ # patch_is_pos = [False, True, False]
816
+ # per_sample_weight = [1.0, 1.0, 1.0]
817
+ # clip_seg_tta.execute_tta(
818
+ # patch_indices,
819
+ # patch_is_pos,
820
+ # tta_steps=10,
821
+ # per_sample_weight=per_sample_weight,
822
+ # num_viz_steps=2,
823
+ # viz_heatmap=True)
824
+
825
+ # # print("Executing step 4b...")
826
+ # patch_indices = [554, 45, 422]
827
+ # patch_is_pos = [False, True, False]
828
+ # per_sample_weight = [0.0, 1.0, 0.8]
829
+ # clip_seg_tta.execute_tta(
830
+ # patch_indices,
831
+ # patch_is_pos,
832
+ # tta_steps=10,
833
+ # per_sample_weight=per_sample_weight,
834
+ # num_viz_steps=2,
835
+ # viz_heatmap=True)
836
+
837
+
838
+ ###########################################
839
+ # Text modality
840
+ # NOTE: Less aggressive vs image modality
841
+ ###########################################
842
+
843
+ # # print("Executing step 4b...")
844
+ # patch_indices = [45, 422]
845
+ # patch_is_pos = [True, False]
846
+ # per_sample_weight = [1.0, 1.0]
847
+ # clip_seg_tta.execute_tta(
848
+ # patch_indices,
849
+ # patch_is_pos,
850
+ # tta_steps=10,
851
+ # per_sample_weight=per_sample_weight,
852
+ # modality="text",
853
+ # num_viz_steps=2,
854
+ # viz_heatmap=True)
855
+
856
+ # # print("Executing step 4b...")
857
+ # patch_indices = [45, 422]
858
+ # patch_is_pos = [True, False]
859
+ # per_sample_weight = [0.1, 1.0]
860
+ # clip_seg_tta.execute_tta(
861
+ # patch_indices,
862
+ # patch_is_pos,
863
+ # tta_steps=10,
864
+ # per_sample_weight=per_sample_weight,
865
+ # modality="text",
866
+ # num_viz_steps=2,
867
+ # viz_heatmap=True)
868
+
869
+ # # print("Executing step 4b...")
870
+ # patch_indices = [45, 422]
871
+ # patch_is_pos = [True, False]
872
+ # per_sample_weight = [1.0, 0.1]
873
+ # clip_seg_tta.execute_tta(
874
+ # patch_indices,
875
+ # patch_is_pos,
876
+ # tta_steps=10,
877
+ # per_sample_weight=per_sample_weight,
878
+ # modality="text",
879
+ # num_viz_steps=2,
880
+ # viz_heatmap=True)
881
+
882
+
883
+ ###########################################
884
+ # Combined modality
885
+ # NOTE: Best of both worlds
886
+ ###########################################
887
+
888
+ # # print("Executing step 4b...")
889
+ # patch_indices = [45, 422]
890
+ # patch_is_pos = [True, False]
891
+ # per_sample_weight = [1.0, 1.0]
892
+ # clip_seg_tta.execute_tta(
893
+ # patch_indices,
894
+ # patch_is_pos,
895
+ # tta_steps=10,
896
+ # per_sample_weight=per_sample_weight,
897
+ # modality="combined",
898
+ # num_viz_steps=2,
899
+ # viz_heatmap=True)
900
+
901
+ # # print("Executing step 4b...")
902
+ # patch_indices = [45, 422]
903
+ # patch_is_pos = [True, False]
904
+ # per_sample_weight = [0.1, 1.0]
905
+ # clip_seg_tta.execute_tta(
906
+ # patch_indices,
907
+ # patch_is_pos,
908
+ # tta_steps=10,
909
+ # per_sample_weight=per_sample_weight,
910
+ # modality="combined",
911
+ # num_viz_steps=2,
912
+ # viz_heatmap=True)
913
+
914
+ # # print("Executing step 4b...")
915
+ # patch_indices = [45, 422]
916
+ # patch_is_pos = [True, False]
917
+ # per_sample_weight = [1.0, 0.1]
918
+ # clip_seg_tta.execute_tta(
919
+ # patch_indices,
920
+ # patch_is_pos,
921
+ # tta_steps=10,
922
+ # per_sample_weight=per_sample_weight,
923
+ # modality="combined",
924
+ # num_viz_steps=2,
925
+ # viz_heatmap=True)
926
+
927
+ ###########################################
928
+ # Real examples
929
+ ###########################################
930
+
931
+
932
+
933
+
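The TTA objective in execute_tta above is the negative log-likelihood of a spatial Poisson point process over satellite patches: the sigmoid similarity acts as the per-patch intensity, and patch_is_pos provides the observed counts, so each patch contributes intensity minus count times log-intensity (the constant log k! term is dropped), weighted by neg_sample_weight and pos_sample_weight. A minimal standalone sketch of that loss, with hypothetical toy values:

```python
import torch

def weighted_poisson_nll(patch_probs, counts, pos_w=1.0, neg_w=1.0, eps=1e-6):
    """Per-patch Poisson NLL (constant log(k!) dropped), summed over observed patches."""
    return (neg_w * patch_probs - pos_w * counts * torch.log(patch_probs + eps)).sum()

# Toy example: one observed positive patch and one observed empty patch.
probs = torch.tensor([0.9, 0.1])   # per-patch intensities (sigmoid similarities)
counts = torch.tensor([1.0, 0.0])  # targets found in each patch
print(weighted_poisson_nll(probs, counts))          # low loss: predictions agree with the counts
print(weighted_poisson_nll(probs.flip(0), counts))  # higher loss: probability mass on the wrong patch
```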
Taxabind/Taxabind/SatBind/config.py ADDED
@@ -0,0 +1,61 @@
1
+ from easydict import EasyDict as edict
2
+
3
+ config = edict()
4
+
5
+ # Pixel level CLIP training
6
+ config.img_dir = '/mnt/hdd/inat2021_ds/inat21'
7
+ config.imo_dir = '/mnt/hdd/inat2021_ds/sat_train_jpg_512px' # sat_train_jpg_256px, sat_train_jpg_512px
8
+ config.imo_dir_val = '/mnt/hdd/inat2021_ds/sat_test_jpg_512px' # sat_test_jpg_256px, sat_test_jpg_512px
9
+ config.train_json_path = '/mnt/hdd/inat2021_ds/inat21_train.json'
10
+ config.val_json_path = '/mnt/hdd/inat2021_ds/inat21_val.json' # no filter needed
11
+ # config.sat_to_img_ids_train_json_path = '/mnt/hdd/inat2021_ds/filtered_mapping_sat_to_img_ids_train.json'
12
+ config.sat_to_img_ids_train_json_path = '/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/FINAL_COUNTS_COMBINED_top4000_pos_top55000_neg/SEARCH_DS_filtered_mapping_sat_to_img_ids_train_3-20counts.json'
13
+ config.sat_to_img_ids_val_json_path = '/mnt/hdd/inat2021_ds/filtered_mapping_sat_to_img_ids_val.json'
14
+ # config.filtered_train_json_path = '/mnt/hdd/inat2021_ds/inat21_filtered_pixel_clip_v2_train.json'
15
+ # config.filtered_val_json_path = '/mnt/hdd/inat2021_ds/inat21_filtered_pixel_clip_val.json' # no filter needed
16
+
17
+ # # Pixel level CLIP training (Toy example)
18
+ # config.img_dir = '../scripts/preprocess/expt/fake_img_dataset'
19
+ # config.imo_dir = '../scripts/preprocess/expt/sat_train_dataset'
20
+ # config.imo_dir_val = '../scripts/preprocess/expt/sat_val_dataset'
21
+ # config.train_json_path = '../scripts/preprocess/expt/inat2021_train_micro.json'
22
+ # config.val_json_path = '../scripts/preprocess/expt/inat2021_val_micro.json'
23
+ # config.sat_to_img_ids_train_json_path = '../scripts/preprocess/expt/sat_filtered/filtered_mapping_sat_to_img_ids_train_micro.json'
24
+ # config.sat_to_img_ids_val_json_path = '../scripts/preprocess/expt/sat_filtered/filtered_mapping_sat_to_img_ids_val_micro.json'
25
+ # # config.filtered_train_json_path = '../scripts/preprocess/expt/inat2021_train_micro.json'
26
+ # # config.filtered_val_json_path = '../scripts/preprocess/expt/inat2021_val_micro.json' # no filter needed
27
+
28
+ # # Image level CLIP training
29
+ # config.img_dir = '/mnt/hdd/inat2021_ds/inat21'
30
+ # config.imo_dir = '/mnt/hdd/inat2021_ds/sat_train_jpg'
31
+ # config.imo_dir_val = '/mnt/hdd/inat2021_ds/sat_test_jpg'
32
+ # config.train_json_path = '/mnt/hdd/inat2021_ds/inat21_filtered_img_clip_train.json'
33
+ # config.val_json_path = '/mnt/hdd/inat2021_ds/inat21_val.json' # no filter needed
34
+
35
+ # # Image level CLIP training (Toy example)
36
+ # config.img_dir = '../scripts/preprocess/expt/fake_img_dataset'
37
+ # config.imo_dir = '../scripts/preprocess/expt/sat_train_dataset'
38
+ # config.imo_dir_val = '../scripts/preprocess/expt/sat_val_dataset'
39
+ # config.train_json_path = '../scripts/preprocess/expt/inat2021_train_micro.json'
40
+ # config.val_json_path = '../scripts/preprocess/expt/inat2021_val_micro.json'
41
+
42
+ # batch_size * accumulate_grad_batches * devices must stay constant (e.g. 256 * 8 * 2 = 4096)
43
+ config.batch_size = 32 # 256
44
+ config.lr = 1e-4 # 1e-4
45
+ config.accumulate_grad_batches = 64
46
+ config.max_epochs = 20
47
+ config.num_workers = 16
48
+ config.devices = 2 # 2
49
+ config.val_check_interval = 0.5
50
+ config.sat_encoder = 'openai/clip-vit-large-patch14-336' # openai/clip-vit-base-patch16, openai/clip-vit-large-patch14-336
51
+ config.patch_size = 14
52
+
53
+ config.save_dir = 'checkpoints'
54
+ config.filename = 'satbind-{epoch:02d}-{val_loss:.2f}'
55
+
56
+ config.locked_tuning = True
57
+
58
+ config.resume_from_checkpoint = False
59
+ config.resume_checkpoint_name = 'satbind-resume'
60
+
61
+ print("config: \n", config)
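As the comment in config.py notes, the effective global batch size is the product of the per-device batch size, the gradient-accumulation steps, and the device count, and should be held constant when any single factor changes. A quick sketch of that bookkeeping with the values above:

```python
# Effective global batch size for the settings in config.py:
# 32 (per-device) * 64 (accumulate_grad_batches) * 2 (devices) = 4096,
# matching the 256 * 8 * 2 = 4096 reference configuration in the comment.
batch_size, accumulate_grad_batches, devices = 32, 64, 2
effective_batch = batch_size * accumulate_grad_batches * devices
assert effective_batch == 256 * 8 * 2 == 4096
print(effective_batch)
```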
Taxabind/Taxabind/SatBind/dataset.py ADDED
@@ -0,0 +1,396 @@
1
+ import math
2
+ from pathlib import Path
3
+
4
+ from matplotlib import pyplot as plt
5
+ from torch.utils.data import Dataset
6
+ from torchvision import transforms
7
+ import json
8
+ import os
9
+ from PIL import Image
10
+ from datetime import datetime
11
+ from torchvision.transforms import v2
12
+ import torch
13
+ import numpy as np
14
+ import glob
15
+
16
+ # Just to import SoundBind (1 dir back)
17
+ import sys
18
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
19
+ from SoundBind.sound_encoder import get_audio_clap
20
+
21
+ class SatNatDataset(Dataset):
22
+ def __init__(self, img_dir, imo_dir, json_path, sat_to_img_ids_json_path, patch_size, mode='train', get_img_path=False, sat_to_img_ids_json_is_train_dict=True, tax_to_filter_val="", sound_data_path=None):
23
+ self.img_dir = img_dir
24
+ self.imo_dir = imo_dir
25
+ self.patch_size = patch_size
26
+ self.get_img_path = get_img_path
27
+ self.mode = mode
28
+ self.sat_to_img_ids_json_is_train_dict = sat_to_img_ids_json_is_train_dict
29
+ self.tax_to_filter_val = tax_to_filter_val
30
+ self.sound_data_path = sound_data_path
31
+
32
+ # ADDED
33
+ self.current_epoch = 0
34
+
35
+ self.json = json.load(open(json_path, 'r'))
36
+ # self.sat_filt_json = json.load(open(sat_filtered_json_path, 'r'))
37
+ self.sat_to_img_ids_json_path = json.load(open(sat_to_img_ids_json_path, 'r'))
38
+ self.images = self.json['images']
39
+ self.annot = self.json['annotations']
40
+ for i in range(len(self.images)):
41
+ assert self.images[i]['id'] == self.annot[i]['id']
42
+ self.images[i]['label'] = self.annot[i]['category_id']
43
+ self.filtered_json = [d for d in self.images if d['latitude'] is not None and d['longitude'] is not None]
44
+ self.species_text = list(set([" ".join(d['file_name'].split("/")[1].split("_")[1:]) for d in self.filtered_json]))
45
+
46
+ # self.sat_filtered_json = [d for d in self.sat_filt_json['images'] if d['latitude'] is not None and d['longitude'] is not None]
47
+ # self.sat_paths = {d['id']: str(d['id'])+'_'+str(d['latitude'])+'_'+str(d['longitude'])+'.jpg' for d in self.sat_filtered_json}
48
+ self.inat_json_dict = {
49
+ "images": {img["id"]: img for img in self.images},
50
+ "annotations": {ann["id"]: ann for ann in self.annot},
51
+ }
52
+
53
+ # Expand dict
54
+ self.sat_to_img_ids_tuples = []
55
+ if self.sat_to_img_ids_json_is_train_dict:
56
+ # for i in range(len(self.sat_filtered_json)):
57
+ # id = self.sat_filtered_json[i]['id']
58
+ # sat_id = Path(self.sat_paths[id]).stem # remove file extension
59
+ # img_ids = self.sat_to_img_ids_json_path[sat_id]["img_ids"]
60
+ for sat_key, sat_sample in self.sat_to_img_ids_json_path.items():
61
+ id = sat_sample["id"] # int(sat_key.split("_")[0]) # sat_sample['id']
62
+ sat_path = sat_sample["sat_path"]
63
+ img_ids = sat_sample["img_ids"]
64
+ # img_locs = sat_sample["target_positions"] # NOTE: Not accurate after resize
65
+ for img_id in img_ids:
66
+ self.sat_to_img_ids_tuples.append((id, sat_path, img_id))
67
+ print("len(self.sat_to_img_ids_json_path): ", len(self.sat_to_img_ids_json_path))
68
+ print("len(self.sat_to_img_ids_tuples): ", len(self.sat_to_img_ids_tuples))
69
+ else:
70
+ self.filtered_val_ds_by_tax = [d for d in self.sat_to_img_ids_json_path if self.tax_to_filter_val in d['taxonomy']]
71
+
72
+ if mode == 'train':
73
+ self.img_transform = transforms.Compose([
74
+ transforms.Resize((256, 256)),
75
+ transforms.RandomCrop((224, 224)),
76
+ transforms.RandomHorizontalFlip(0.5),
77
+ transforms.GaussianBlur(5, (0.01, 1.0)),
78
+ transforms.ToTensor(),
79
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
80
+ std=[0.229, 0.224, 0.225])
81
+ ])
82
+ self.imo_transform = transforms.Compose([
83
+ # transforms.Resize((256,256)),
84
+ # transforms.RandomCrop((224, 224)),
85
+ # transforms.RandomHorizontalFlip(0.5),
86
+ # transforms.Resize((224,224)),
87
+ transforms.Resize((336,336)),
88
+ transforms.GaussianBlur(5, (0.01, 1.0)),
89
+ transforms.ToTensor(),
90
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
91
+ std=[0.229, 0.224, 0.225])
92
+ ])
93
+ else:
94
+ self.img_transform = transforms.Compose([
95
+ transforms.Resize((256, 256)),
96
+ transforms.CenterCrop((224, 224)),
97
+ transforms.ToTensor(),
98
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
99
+ std=[0.229, 0.224, 0.225])
100
+ ])
101
+ self.imo_transform = transforms.Compose([
102
+ # transforms.Resize((256,256)),
103
+ # transforms.CenterCrop((224, 224)),
104
+ # transforms.Resize((224,224)),
105
+ transforms.Resize((336,336)),
106
+ transforms.ToTensor(),
107
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
108
+ std=[0.229, 0.224, 0.225])
109
+ ])
110
+ self.debug_img_viz_transform = transforms.Compose([
111
+ transforms.Resize((256, 256)),
112
+ transforms.CenterCrop((224, 224))
113
+ ])
114
+ self.debug_imo_viz_transform = transforms.Compose([
115
+ # transforms.Resize((224,224))
116
+ transforms.Resize((336,336))
117
+ ])
118
+
119
+ def __len__(self):
120
+ # return len(self.filtered_json)
121
+ # return len(self.sat_filtered_json)
122
+ return len(self.sat_to_img_ids_tuples)
123
+
124
+ def __getitem__(self, idx):
125
+
126
+ # # DEBUG: Visualize original image and IMO pair
127
+ # self.debug_viz_img_imo_orig_pair(idx)
128
+ if not self.sat_to_img_ids_json_is_train_dict:
129
+ print("JSON is not a dict. Please reformat it for training!")
130
+ exit()
131
+
132
+ ## Pixel-level CLIP
133
+ # Get sat img
134
+ # id = self.sat_filtered_json[idx]['id']
135
+ id, sat_path, img_id = self.sat_to_img_ids_tuples[idx]
136
+ imo_path = os.path.join(self.imo_dir, sat_path)
137
+ imo = self.imo_transform(Image.open(imo_path))
138
+ sat_id = Path(sat_path).stem # remove file extension
139
+
140
+ img_path = os.path.join(self.img_dir, self.inat_json_dict["images"][img_id]["file_name"])
141
+ img = self.img_transform(Image.open(img_path))
142
+
143
+ # # Map lat-lon to pixel in sat img
144
+ sat_min_lon = self.sat_to_img_ids_json_path[sat_id]["sat_bounds"]["min_lon"]
145
+ sat_min_lat = self.sat_to_img_ids_json_path[sat_id]["sat_bounds"]["min_lat"]
146
+ sat_max_lon = self.sat_to_img_ids_json_path[sat_id]["sat_bounds"]["max_lon"]
147
+ sat_max_lat = self.sat_to_img_ids_json_path[sat_id]["sat_bounds"]["max_lat"]
148
+
149
+ img_lon = self.inat_json_dict["images"][img_id]["longitude"]
150
+ img_lat = self.inat_json_dict["images"][img_id]["latitude"]
151
+ row, col = self.latlon_to_pixel(img_lat, img_lon, sat_min_lat, sat_max_lat, sat_min_lon, sat_max_lon, imo.shape[2], imo.shape[1])
152
+
153
+ patch_idx = self.pixel_to_patch_idx(row, col, self.patch_size, imo.shape[2], imo.shape[1])
154
+ patch_idx += 1 # account for [CLS] token at the start of ViT input sequence
155
+
156
+ species_text = " ".join(self.inat_json_dict["images"][img_id]['file_name'].split("/")[1].split("_")[1:])
157
+
158
+ # # # DEBUG: Visualize patch
159
+ # img_debug = np.array(self.debug_img_viz_transform(Image.open(img_path)))
160
+ # imo_debug = np.array(self.debug_imo_viz_transform(Image.open(imo_path)))
161
+ # self.debug_viz_patch(img_debug, imo_debug, patch_idx, self.patch_size, species_text)
162
+
163
+ if self.get_img_path:
164
+ return img_path, imo_path, img, imo, self.inat_json_dict["annotations"][img_id]['category_id'], patch_idx, species_text, self.species_text.index(species_text)
165
+ else:
166
+ return img, imo, self.inat_json_dict["annotations"][img_id]['category_id'], patch_idx, species_text, self.species_text.index(species_text)
167
+
168
+
169
+ # NOTE: Image-level CLIP
170
+ # img_path = os.path.join(self.img_dir, self.filtered_json[idx]['file_name'])
171
+ # imo_path = os.path.join(self.imo_dir, self.sat_paths[self.filtered_json[idx]['id']])
172
+ # img = self.img_transform(Image.open(img_path))
173
+ # imo = self.imo_transform(Image.open(imo_path))
174
+ # species_text = " ".join(self.filtered_json[idx]['file_name'].split("/")[1].split("_")[1:])
175
+ # return img, imo, self.filtered_json[idx]['label'], species_text, self.species_text.index(species_text)
176
+
177
+ def latlon_to_pixel(self, lat, lon, lat_min, lat_max, lon_min, lon_max, img_width, img_height):
178
+ lat_res = (lat_max - lat_min) / img_height
179
+ lon_res = (lon_max - lon_min) / img_width
180
+
181
+ col = int(math.floor((lon - lon_min) / lon_res))
182
+ row = int(math.floor((lat_max - lat) / lat_res))
183
+
184
+ return row, col
185
+
186
+ def pixel_to_patch_idx(self, row, col, patch_size, img_width, img_height):
187
+ patch_size_width = patch_size
188
+ patch_size_height = patch_size
189
+ patch_row = row // patch_size_height
190
+ patch_col = col // patch_size_width
191
+ patch_idx = patch_row * (img_width // patch_size) + patch_col
192
+
193
+ return patch_idx
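+ # Example (hypothetical numbers): with patch_size=14 and a 336x336 satellite image there are
+ # 24x24 patches; pixel (row=100, col=50) maps to patch_row=7, patch_col=3, so
+ # patch_idx = 7 * 24 + 3 = 171 (the +1 [CLS] offset is applied by the caller in __getitem__).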
194
+
195
+ def set_epoch(self, epoch):
196
+ self.current_epoch = epoch
197
+
198
+ ###########################################################
199
+
200
+ def debug_viz_img_imo_orig_pair(self, idx):
201
+
202
+ img_path = os.path.join(self.img_dir, self.filtered_json[idx]['file_name'])
203
+ imo_path = os.path.join(self.imo_dir, self.sat_paths[self.filtered_json[idx]['id']])
204
+ # img = self.img_transform(Image.open(img_path))
205
+ # imo = self.imo_transform(Image.open(imo_path))
206
+ # species_text = " ".join(self.filtered_json[idx]['file_name'].split("/")[1].split("_")[1:])
207
+
208
+ # Create a side-by-side plot
209
+ fig, axes = plt.subplots(1, 2, figsize=(10, 5))
210
+
211
+ img_np = Image.open(img_path)
212
+ imo_np = Image.open(imo_path)
213
+
214
+ axes[0].imshow(img_np)
215
+ axes[0].set_title('Image')
216
+ axes[0].axis('off')
217
+
218
+ axes[1].imshow(imo_np)
219
+ axes[1].set_title('IMO')
220
+ axes[1].axis('off')
221
+
222
+ plt.tight_layout()
223
+ plt.show()
224
+
225
+
226
+ def debug_viz_patch(self, img_np, imo_np, patch_idx, patch_size, species_text):
227
+
228
+ # ----- Compute the patch boundaries -----
229
+ # Since patch_idx was incremented by 1 (for the [CLS] token), subtract 1 to get the zero-based index.
230
+ patch_idx_actual = patch_idx - 1
231
+
232
+ # Get the image dimensions from the satellite image (imo).
233
+ # Note: in __getitem__ the transformed tensor has imo.shape[2] = width and imo.shape[1] = height; here imo_np is (H, W, C).
234
+ H, W = imo_np.shape[0], imo_np.shape[1]
235
+
236
+ # Compute the number of patches along each image dimension.
237
+ patches_per_col = W // patch_size
238
+ patches_per_row = H // patch_size
239
+
240
+ # Determine the patch's row and column in the grid.
241
+ patch_row = patch_idx_actual // patches_per_col
242
+ patch_col = patch_idx_actual % patches_per_col  # use the same W // patch_size divisor as pixel_to_patch_idx
243
+
244
+ # Calculate pixel boundaries for the patch.
245
+ x_start = patch_col * patch_size
246
+ x_end = (patch_col + 1) * patch_size
247
+ y_start = patch_row * patch_size
248
+ y_end = (patch_row + 1) * patch_size
249
+
250
+ # ----- Extract the patch from the satellite image -----
251
+ imo_patch = imo_np[y_start:y_end, x_start:x_end, :]
252
+
253
+ # ----- Create a highlighted version of the satellite image -----
254
+ # We will grey out the patch region.
255
+ imo_highlight = imo_np.copy()
256
+
257
+ # Define a grey value based on the image's data type.
258
+ if imo_highlight.dtype == np.uint8:
259
+ grey_value = 128 # for 0-255 images
260
+ else:
261
+ grey_value = 0.5 # for images normalized to [0, 1]
262
+
263
+ # Replace the patch region with the grey color.
264
+ imo_highlight[y_start:y_end, x_start:x_end, :] = grey_value
265
+
266
+ # ----- Plotting the three images side by side -----
267
+ fig, axes = plt.subplots(1, 3, figsize=(15, 5))
268
+ fig.suptitle(f"Species: {species_text}")
269
+
270
+ # Left: Satellite image with the patch greyed out.
271
+ axes[0].imshow(imo_highlight)
272
+ axes[0].set_title("IMO with Patch Greyed Out")
273
+ axes[0].axis("off")
274
+
275
+ # Middle: Cropped patch from the satellite image.
276
+ axes[1].imshow(imo_patch)
277
+ axes[1].set_title("IMO Patch")
278
+ axes[1].axis("off")
279
+
280
+ # Right: Ground image.
281
+ axes[2].imshow(img_np)
282
+ axes[2].set_title("Ground Image")
283
+ axes[2].axis("off")
284
+
285
+ plt.tight_layout()
286
+ plt.show()
287
+
288
+ ###########################################################
289
+
290
+ def get_search_ds_data(self, idx):
291
+
292
+ if self.sat_to_img_ids_json_is_train_dict:
293
+ print("Json is dict. Please reformat for target search!")
294
+ exit()
295
+
296
+
297
+ # sat_sample = self.sat_to_img_ids_json_path[idx] if self.sat_to_img_ids_json_is_train_dict else self.filtered_val_ds_by_tax[idx]
298
+ bounded_idx = idx % len(self.filtered_val_ds_by_tax)
299
+ if idx >= len(self.filtered_val_ds_by_tax):
300
+ print("??? bounded_idx: ", bounded_idx)
301
+ sat_sample = self.filtered_val_ds_by_tax[bounded_idx]
302
+ target_positions = sat_sample["target_positions"]
303
+ imo_path = os.path.join(self.imo_dir, sat_sample["sat_path"])
304
+ imo = self.imo_transform(Image.open(imo_path))
305
+
306
+ img_paths = []
307
+ imgs = []
308
+ species_texts = []
309
+ for img_id in sat_sample["img_ids"]:
310
+ img_path = os.path.join(self.img_dir, self.inat_json_dict["images"][img_id]["file_name"])
311
+ img = self.img_transform(Image.open(img_path))
312
+ img_paths.append(img_path)
313
+ imgs.append(img)
314
+
315
+ species_text = " ".join(self.inat_json_dict["images"][img_id]['file_name'].split("/")[1].split("_")[1:])
316
+ species_texts.append(species_text)
317
+ imgs = torch.stack(imgs) # Stack all images into a single tensor
318
+
319
+ if len(set(species_texts)) > 1:
320
+ print("Species mismatch in search dataset!")
321
+ exit()
322
+ else:
323
+ species_name = species_texts[0]
324
+
325
+ gt_mask_name = str(sat_sample["id"]) + "_" + sat_sample["taxonomy"] + ".png" # Saved with 2x '_' by accident
326
+ gt_mask_name = gt_mask_name.replace(" ", "_")
327
+
328
+ # Consider sound if valid
329
+ sounds = []
330
+ sound_ids = []
331
+ # if self.sound_data_path is not None and "sound_ids" in sat_sample:
332
+ # for sound_id in sat_sample["sound_ids"]:
333
+ # sound_path = os.path.join(self.sound_data_path,"sounds_mp3",str(sound_id)+"."+'mp3')
334
+ # sound = get_audio_clap(sound_path) # , sound_format)
335
+ # for k in sound.keys():
336
+ # sound[k] = sound[k].squeeze(0)
337
+ # sounds.append(sound)
338
+ # sound_ids.append(sound_id)
339
+ # sounds = torch.stack(sounds)
340
+
341
+ # NOTE: Cannot stack cos sound format is different (Use first index)
342
+ if self.sound_data_path is not None and "sound_ids" in sat_sample:
343
+ sound_id = sat_sample["sound_ids"][0]
344
+ sound_path = os.path.join(self.sound_data_path,"sounds_mp3",str(sound_id)+"."+'mp3')
345
+ sound = get_audio_clap(sound_path) # , sound_format)
346
+ # NOTE: no need to squeeze if not sampled by Pytorch Lightning
347
+ # for k in sound.keys():
348
+ # sound[k] = sound[k].squeeze(0)
349
+ sounds = [sound]
350
+ sound_ids = [sound_id]
351
+
352
+ ###################################
353
+
354
+ ########################################################################################################
355
+ # # TEMP OVERRIDE:
356
+ # imo_path = "/home/user/search-tta-demo/examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/80645_39.76079_-74.10316.jpg"
357
+ # imo = self.imo_transform(Image.open(imo_path).convert("RGB"))
358
+ # img_paths = ["/home/user/search-tta-demo/examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/cc1ebaf9-899d-49f2-81c8-d452249a8087.jpg"]
359
+ # imgs = [self.img_transform(Image.open(img_paths[0]))]
360
+ # imgs = torch.stack(imgs)
361
+ # species_name = "Animalia Chordata Aves Charadriiformes Laridae Larus marinus"
362
+ # gt_mask_name = "Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus"
363
+
364
+ # sounds = []
365
+ # sound_ids = []
366
+ # sound_id = 89758229
367
+ # sound_path = "/home/user/search-tta-demo/examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/89758229.mp3"
368
+ # sound = get_audio_clap(sound_path) # , sound_format)
369
+ # # NOTE: no need to squeeze if not sampled by Pytorch Lightning
370
+ # # for k in sound.keys():
371
+ # # sound[k] = sound[k].squeeze(0)
372
+ # sounds = [sound]
373
+ # sound_ids = [sound_id]
374
+
375
+ # # 2020 and beyond
376
+ # target_positions = [
377
+ # # (84, 57), # 2015
378
+ # # (59, 323), # Shouldn't be in water
379
+ # (50, 347), # ADDED: Shouldn't be in water
380
+ # # (54, 332), # In water
381
+ # # (19, 507), # 2014
382
+ # (68, 444), # shifted down
383
+ # (136, 505),
384
+ # (142, 465),
385
+ # (345, 343),
386
+ # (130, 241),
387
+ # (367, 281), # Should be more right actually
388
+ # # (463, 351),
389
+ # # (513, 381), # Out of bound
390
+ # # (405, 162), # 2019
391
+ # (315, 80) # Shifted down-left
392
+ # ]
393
+
394
+ ########################################################################################################
395
+
396
+ return img_paths, imo_path, imgs, imo, sounds, sound_ids, species_name, target_positions, gt_mask_name
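+ # Hedged usage sketch (index and shapes are illustrative only):
+ #   img_paths, imo_path, imgs, imo, sounds, sound_ids, species_name, \
+ #       target_positions, gt_mask_name = dataset.get_search_ds_data(0)
+ #   # imgs: (K, C, H, W) tensor of the K ground images tied to one satellite tile
+ #   # sounds: empty list, or a single CLAP-processed clip when sound data exists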
Taxabind/Taxabind/SatBind/kmeans_clustering.py ADDED
@@ -0,0 +1,620 @@
1
+ import math
2
+ import os
3
+ import random
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ import matplotlib.patches as mpatches
7
+ import matplotlib.colors as colors
8
+ from sklearn.cluster import KMeans
9
+ from sklearn.metrics import silhouette_score
10
+ from kneed import KneeLocator
11
+ import scipy.ndimage as ndimage
12
+ from collections import Counter
13
+
14
+ # import matplotlib
15
+ # matplotlib.use("Agg") # <-- key line to avoid tkinter dependency
16
+
17
+ class CombinedSilhouetteInertiaClusterer:
18
+ """
19
+ A class that:
20
+ 1) Runs combined silhouette & WGSS (inertia) analysis to find the best k.
21
+ 2) Performs a 2D reshape of patch labels.
22
+ 3) Optionally smooths the cluster map.
23
+ 4) Optionally plots the cluster map next to the heatmap (see `plot_cluster_map`).
+ 5) Returns the cluster labels reshaped to the 2D patch grid (via `fit_predict`).
25
+ """
26
+ def __init__(
27
+ self,
28
+ k_min=1,
29
+ k_max=8,
30
+ k_avg_max=4,
31
+ silhouette_threshold=0.15,
32
+ relative_threshold=0.15,
33
+ random_state=0,
34
+ min_patch_size=5,
35
+ n_smooth_iter=2,
36
+ ignore_label=-1,
37
+ plot=False,
38
+ gifs_dir = "./"
39
+ ):
40
+ """
41
+ Parameters
42
+ ----------
43
+ k_min : int
44
+ Minimum number of clusters for KMeans.
45
+ k_max : int
46
+ Maximum number of clusters for KMeans.
47
+ k_avg_max : int
48
+ Upper bound on k after combining elbow & silhouette if they disagree.
49
+ silhouette_threshold : float
50
+ Minimum silhouette score at k=2 to justify splitting.
51
+ relative_threshold : float
52
+ Minimum % improvement in inertia from k=1→k=2 to justify splitting.
53
+ random_state : int
54
+ RNG seed for KMeans.
55
+ min_patch_size : int
56
+ Patches smaller than this threshold are smoothed.
57
+ n_smooth_iter : int
58
+ Number of smoothing iterations.
59
+ ignore_label : int
60
+ Label to ignore in smoothing step.
61
+ """
62
+ self.k_min = k_min
63
+ self.k_max = k_max
64
+ self.k_avg_max = k_avg_max
65
+ self.silhouette_threshold = silhouette_threshold
66
+ self.relative_threshold = relative_threshold
67
+ self.random_state = random_state
68
+
69
+ self.min_patch_size = min_patch_size
70
+ self.n_smooth_iter = n_smooth_iter
71
+ self.ignore_label = ignore_label
72
+ self.plot = plot
73
+ self.gifs_dir = gifs_dir
74
+
75
+ self.final_k = None
76
+ self.final_labels_1d = None # 1D array of cluster labels (unsmoothed)
77
+ self.smoothed_labels_2d = None # 2D array of labels (smoothed)
78
+ self.kmeans_frame_files = []
79
+
80
+ ##############################
81
+ # Helper functions
82
+ ##############################
83
+ def find_largest_connected_component(self, cluster_map, label):
84
+ """
85
+ Identifies the largest connected component (patch) of a given label in the cluster map.
86
+ Uses **4-connectivity** (N, S, E, W) to define connected regions.
87
+
88
+ Parameters:
89
+ -----------
90
+ cluster_map : ndarray (H, W)
91
+ The 2D label map.
92
+ label : int
93
+ The label whose largest connected patch we want to find.
94
+
95
+ Returns:
96
+ --------
97
+ largest_component_mask : ndarray (H, W)
98
+ A binary mask where `True` indicates pixels belonging to the largest connected component.
99
+ """
100
+ mask = (cluster_map == label)
101
+
102
+ # Label all connected components using **4-connectivity**
103
+ labeled_components, num_components = ndimage.label(
104
+ mask,
105
+ structure=np.array([[0,1,0], [1,1,1], [0,1,0]])
106
+ )
107
+
108
+ if num_components == 0:
109
+ return np.zeros_like(cluster_map, dtype=bool) # No components found
110
+
111
+ # Compute sizes of each connected component
112
+ component_sizes = np.bincount(labeled_components.ravel())[1:] # Ignore background (0)
113
+ largest_component_idx = np.argmax(component_sizes) + 1 # +1 to skip background
114
+
115
+ # Create a mask for the largest component
116
+ largest_component_mask = (labeled_components == largest_component_idx)
117
+ return largest_component_mask
118
+
119
+
120
+ def smooth_small_patches(self, labels_2d, min_patch_size=5, n_iter=1, ignore_label=-1):
121
+ """
122
+ Smooths **only small disconnected patches** (relative to the largest patch) within each cluster.
123
+ Uses **majority voting** in a 4-connected neighborhood (N, S, E, W).
124
+ Ignores neighboring pixels that have the `ignore_label` when computing the majority vote.
125
+ Implements a **tie-breaker**: If there is a tie between the original class and another class,
126
+ the original class is retained.
127
+
128
+ Parameters:
129
+ -----------
130
+ labels_2d : ndarray of shape (H, W)
131
+ Integer labels (e.g., cluster IDs), including an ignore label (-1).
132
+ min_patch_size : int
133
+ Patches smaller than this threshold are considered **isolated** and will be smoothed.
134
+ n_iter : int
135
+ Number of smoothing iterations.
136
+ ignore_label : int
137
+ Label value to be ignored when counting neighbors.
138
+
139
+ Returns:
140
+ --------
141
+ smoothed_labels : ndarray of shape (H, W)
142
+ The label map after selective smoothing.
143
+ """
144
+ H, W = labels_2d.shape
145
+ smoothed = labels_2d.copy()
146
+
147
+ for _ in range(n_iter):
148
+ new_label = smoothed.copy() # Clone where updates will be applied
149
+
150
+ unique_labels = np.unique(smoothed)
151
+ np.random.shuffle(unique_labels) # Process clusters in random order to avoid bias
152
+
153
+ for label in unique_labels:
154
+ if label == ignore_label:
155
+ continue
156
+
157
+ # Find the largest connected component for this label
158
+ largest_component_mask = self.find_largest_connected_component(smoothed, label)
159
+
160
+ # Identify small disconnected patches
161
+ small_patches_mask = (smoothed == label) & (~largest_component_mask)
162
+
163
+ # Label these small patches separately
164
+ labeled_small_patches, num_patches = ndimage.label(
165
+ small_patches_mask,
166
+ structure=np.array([[0,1,0], [1,1,1], [0,1,0]])
167
+ )
168
+
169
+ # Process each small patch individually
170
+ for patch_id in range(1, num_patches + 1):
171
+ patch_mask = (labeled_small_patches == patch_id)
172
+ patch_size = np.sum(patch_mask)
173
+
174
+ if patch_size < min_patch_size:
175
+ # Only smooth small isolated patches
176
+ for i, j in zip(*np.where(patch_mask)):
177
+ # Collect 4-connected neighborhood (N, S, E, W) **excluding ignore_label**
178
+ neighbors = []
179
+ if i > 0 and smoothed[i-1, j] != ignore_label: # North
180
+ neighbors.append(smoothed[i-1, j])
181
+ if i < H - 1 and smoothed[i+1, j] != ignore_label: # South
182
+ neighbors.append(smoothed[i+1, j])
183
+ if j > 0 and smoothed[i, j-1] != ignore_label: # West
184
+ neighbors.append(smoothed[i, j-1])
185
+ if j < W - 1 and smoothed[i, j+1] != ignore_label: # East
186
+ neighbors.append(smoothed[i, j+1])
187
+
188
+ if neighbors: # Avoid empty neighbor list
189
+ neighbor_counts = Counter(neighbors)
190
+ most_common_label, count = neighbor_counts.most_common(1)[0]
191
+
192
+ # Tie-breaker condition
193
+ tied_labels = [lbl for lbl, cnt in neighbor_counts.items() if cnt == count]
194
+ if len(tied_labels) > 1 and smoothed[i, j] in tied_labels:
195
+ # If there's a tie and the original class is involved, keep the original class
196
+ new_label[i, j] = smoothed[i, j]
197
+ else:
198
+ # Otherwise, assign the most common label
199
+ new_label[i, j] = most_common_label
200
+
201
+ smoothed = new_label.copy() # Apply modifications
202
+
203
+ return smoothed
204
+
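+ # Minimal sketch of the smoothing behaviour (hypothetical labels): a stray
+ # pixel of label 1 surrounded by label 0 is re-assigned to 0 by the 4-connected
+ # majority vote, while the large block of 1s (the largest component) is kept:
+ #   labels = np.array([[0, 0, 0, 1, 1],
+ #                      [0, 1, 0, 1, 1],
+ #                      [0, 0, 0, 1, 1]])
+ #   smoothed = self.smooth_small_patches(labels, min_patch_size=2, n_iter=1)
+ #   # smoothed[1, 1] == 0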
205
+
206
+ def combined_silhouette_inertia_clustering(
207
+ self,
208
+ X,
209
+ k_min=1,
210
+ k_max=8,
211
+ k_avg_max=4,
212
+ silhouette_threshold=0.2,
213
+ relative_threshold=0.05,
214
+ random_state=0
215
+ ):
216
+ """
217
+ Runs KMeans for k in [k_min..k_max] exactly once each,
218
+ collects silhouette scores & inertias, and returns:
219
+ - best_k
220
+ - labels for best_k
221
+ - silhouette_scores (list or None for k=1)
222
+ - inertias (list)
223
+
224
+ The final k is chosen by both silhouette and elbow (WGSS):
225
+ 1) If #points < 2 -> return k=1 and all-zero labels.
226
+ 2) Compare k=1 vs. k=2 improvement in inertia and silhouette(2).
227
+ 3) If both pass threshold, run k=2..k_max, pick best_k_sil (highest silhouette)
228
+ and best_k_elbow (via KneeLocator). If they differ, take their mean (rounded).
229
+ """
230
+ n_samples = len(X)
231
+ if n_samples < 2:
232
+ return 1, np.zeros(n_samples, dtype=int), [None], [None]
233
+
234
+ # --- Fit once for k=1 ---
235
+ km1 = KMeans(n_clusters=1, random_state=random_state).fit(X)
236
+ inertia_k1 = km1.inertia_ / n_samples
237
+ silhouette_k1 = None # undefined for k=1
238
+ # print(f"k=1, inertia={inertia_k1:.4f}, silhouette=NA")
239
+
240
+ if k_max < 2:
241
+ # If k_max=1, no reason to check further
242
+ return 1, km1.labels_, [silhouette_k1], [inertia_k1]
243
+
244
+ # --- Fit once for k=2 ---
245
+ km2 = KMeans(n_clusters=2, random_state=random_state).fit(X)
246
+ inertia_k2 = km2.inertia_ / n_samples
247
+ sil_k2 = silhouette_score(X, km2.labels_)
248
+ # print(f"k=2, inertia={inertia_k2:.4f}, silhouette={sil_k2:.4f}")
249
+
250
+ relative_improvement = (inertia_k1 - inertia_k2) / inertia_k1
251
+ # print(f"[k=1 → k=2] inertia improvement: {relative_improvement:.4%}, sil(k=2)={sil_k2:.4f}")
252
+
253
+ # If improvement is too small or silhouette is too low => remain at k=1
254
+ if (relative_improvement < relative_threshold) or (sil_k2 < silhouette_threshold):
255
+ print("Choosing k=1 (failed thresholds).")
256
+ return 1, km1.labels_, [silhouette_k1, sil_k2], [inertia_k1, inertia_k2]
257
+
258
+ # --- Otherwise fit k=2..k_max and gather inertias & silhouettes ---
259
+ all_k = range(2, k_max + 1)
260
+ kmeans_models = {}
261
+ inertias = []
262
+ silhouettes = []
263
+
264
+ # We already have k=2
265
+ kmeans_models[2] = km2
266
+ inertias.append(inertia_k2)
267
+ silhouettes.append(sil_k2)
268
+
269
+ for k in range(3, k_max + 1):
270
+ km = KMeans(n_clusters=k, random_state=random_state).fit(X)
271
+ kmeans_models[k] = km
272
+
273
+ norm_inertia = km.inertia_ / n_samples
274
+ inertias.append(norm_inertia)
275
+
276
+ # If k>n_samples, silhouette_score is meaningless, but in normal usage k<<n_samples
277
+ sil_val = silhouette_score(X, km.labels_) if k <= n_samples else -1
278
+ silhouettes.append(sil_val)
279
+
280
+ # print(f"k={k}, inertia={norm_inertia:.4f}, silhouette={sil_val:.4f}")
281
+
282
+ # (a) Silhouette-based best_k_sil
283
+ best_idx_sil = np.argmax(silhouettes)
284
+ best_k_sil = best_idx_sil + 2 # offset since silhouettes[0] is k=2
285
+
286
+ # (b) Inertia-based best_k_elbow
287
+ k_candidates = np.arange(2, k_max + 1)
288
+ if len(k_candidates) == 1:
289
+ best_k_elbow = 2
290
+ else:
291
+ kn = KneeLocator(k_candidates, inertias, curve="convex", direction="decreasing")
292
+ best_k_elbow = kn.elbow
293
+ if best_k_elbow is None:
294
+ print("No elbow found => default to k=1.")
295
+ best_k_elbow = 1 # fallback
296
+
297
+ print(f"Silhouette-based best_k={best_k_sil}, elbow-based best_k={best_k_elbow}")
298
+
299
+ # Combine if there's disagreement
300
+ if best_k_sil == best_k_elbow:
301
+ final_k = max(1, min(best_k_sil, k_avg_max)) # best_k_sil
302
+ else:
303
+ avg_k = 0.5 * (best_k_sil + best_k_elbow)
304
+ final_k = int(math.ceil(avg_k))
305
+ final_k = max(1, min(final_k, k_avg_max))
306
+ # print(f"Disagreement => taking final_k=ceil((sil={best_k_sil} + elbow={best_k_elbow})/2) = {final_k}, capped by k_avg_max={k_avg_max}")
307
+ assert (final_k <= k_avg_max), f"Final k={final_k} is greater than k_avg_max={k_avg_max}"
308
+
309
+ # Get final labels from the chosen KMeans model
310
+ if final_k == 1:
311
+ final_labels = km1.labels_
312
+ else:
313
+ final_labels = kmeans_models[final_k].labels_
314
+
315
+ return final_k, final_labels, [silhouette_k1] + silhouettes, [inertia_k1] + inertias
316
+
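+ # Worked illustration of the selection rule above (numbers are hypothetical):
+ # if the silhouette criterion prefers best_k_sil = 3 while the elbow on the
+ # normalized inertias gives best_k_elbow = 6, the combined choice is
+ # ceil((3 + 6) / 2) = 5, which is then capped by k_avg_max (e.g. 4) => final_k = 4.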
317
+
318
+ def compute_region_statistics(self, label_map, heatmap, visited_indices, episode_num=0, step_num=0):
319
+ """
320
+ Computes region statistics for the current smoothed label map.
321
+
322
+ Parameters
323
+ ----------
324
+ heatmap : ndarray, shape (H, W)
325
+ A 2D array of probabilities (or intensities) for each patch.
326
+ Must match the shape of `self.smoothed_labels_2d`.
327
+ visited_indices : list or array of int
328
+ A list of flat indices (0-based) indicating which patches
329
+ the robot has visited. The flat index is equivalent to
330
+ r * W + c for patch at row r, column c.
331
+
332
+ Returns
333
+ -------
334
+ region_dict : dict
335
+ A dictionary keyed by cluster label ID, with values:
336
+ {
337
+ 'num_patches': int,
338
+ 'patches_visited': int,
339
+ 'expectation': float
340
+ }
341
+ """
342
+ # Flatten the cluster map and the heatmap to handle indexing uniformly
343
+ # label_map_2d = label_map # self.smoothed_labels_2d
344
+ # H, W = label_map_2d.shape
345
+ label_map_2d = self.smoothed_labels_2d
346
+ label_map_1d = self.smoothed_labels_2d.ravel()
347
+ heatmap_1d = heatmap.ravel()
348
+
349
+ # Identify unique labels (excluding ignore_label if present)
350
+ unique_labels = np.unique(label_map_1d)
351
+ region_dict = {}
352
+ for lbl in unique_labels:
353
+ if lbl == self.ignore_label: # skip ignore_label
354
+ continue
355
+ region_dict[lbl] = {
356
+ 'num_patches': 0,
357
+ 'patches_visited': 0,
358
+ 'expectation': 0.0
359
+ }
360
+
361
+ # Accumulate totals for all patches
362
+ total_patches = len(label_map_1d)
363
+ for i in range(total_patches):
364
+ lbl = label_map_1d[i]
365
+ if lbl == self.ignore_label:
366
+ continue
367
+ region_dict[lbl]['num_patches'] += 1
368
+ region_dict[lbl]['expectation'] += float(heatmap_1d[i])
369
+
370
+ # # Exponential distribution (waiting time) = num_patches / expected_num_tgts
371
+ for lbl in region_dict:
372
+ region_dict[lbl]['expectation'] = region_dict[lbl]['num_patches'] / region_dict[lbl]['expectation']
373
+
374
+ # Count only unique visited patches by converting to a set.
375
+ unique_visited = set(visited_indices)
376
+ for vi in unique_visited:
377
+ if vi < 0 or vi >= total_patches:
378
+ continue # skip out-of-bounds indices
379
+ lbl = label_map_1d[vi]
380
+ # lbl = self.get_label_id(vi)
381
+ if lbl == self.ignore_label:
382
+ continue
383
+ region_dict[lbl]['patches_visited'] += 1
384
+
385
+ if self.plot:
386
+ self.plot_cluster_map(label_map_2d, heatmap, visited_indices, region_dict, episode_num, step_num)
387
+
388
+ return region_dict
389
+
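+ # Numeric sketch of the statistics above (values are hypothetical): a region
+ # covering 120 patches whose heatmap values sum to 3.0 gets
+ # expectation = 120 / 3.0 = 40.0, i.e. roughly 40 patch visits expected per
+ # target under the exponential waiting-time view; if 15 of the unique visited
+ # indices fall inside that region, its patches_visited is 15.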
390
+ def plot_cluster_map(self, cluster_map, heatmap, path_taken, region_stats_dict, episode_num, step_num, cmap='tab20'):
391
+
392
+ # 4) Plot (side-by-side) if requested
393
+ fig, axes = plt.subplots(1, 3, figsize=(12, 6))
394
+
395
+ axes[0].imshow(cluster_map, cmap='tab20')
396
+ axes[0].set_title(f"Raw KMeans Clusters")
397
+ axes[0].axis('off')
398
+
399
+ axes[1].imshow(heatmap, cmap="viridis")
400
+ axes[1].set_title("Heatmap")
401
+ axes[1].axis('off')
402
+
403
+ axes[2].imshow(cluster_map, cmap='tab20')
404
+ axes[2].set_title("Raw KMeans Clusters")
405
+ axes[2].axis('off')
406
+
407
+ path_rows, path_cols = [], []
408
+ for i, idx in enumerate(path_taken):
409
+ rr = idx // cluster_map.shape[1]
410
+ cc = idx % cluster_map.shape[1]
411
+ path_rows.append(rr)
412
+ path_cols.append(cc)
413
+ axes[2].plot(path_cols, path_rows, c="r", linewidth=2)
414
+ axes[2].plot(path_cols[-1], path_rows[-1], markersize=12, zorder=99, marker="^", ls="-", c="r", mec="black")
415
+ axes[2].plot(path_cols[0], path_rows[0], 'co', c="r", markersize=8, zorder=5)
416
+
417
+ # Create legend patches for each region.
418
+ # We assume region labels are non-negative integers.
419
+ unique_labels = sorted(region_stats_dict.keys())
420
+ # For normalization, use max label value (or 1 if max==0)
421
+ max_label = max(unique_labels) if unique_labels else 1
422
+ cm = plt.get_cmap(cmap)
423
+ legend_patches = []
424
+ for lbl in unique_labels:
425
+ # Normalize the label to [0,1] for colormap lookup.
426
+ norm_value = lbl / max_label if max_label > 0 else 0.5
427
+ color = cm(norm_value)
428
+ patch = mpatches.Patch(color=color, label=f"R{lbl}")
429
+ legend_patches.append(patch)
430
+
431
+ # Add legends to both subplots.
432
+ axes[0].legend(handles=legend_patches, title="Regions", loc='upper right')
433
+ axes[2].legend(handles=legend_patches, title="Regions", loc='upper right')
434
+
435
+ # Build the legend text for each region using the provided format:
436
+ # "R{label}: patches={num_patches}, E={expectation:.3f}, visited={num_visited}"
437
+ legend_lines = []
438
+ for label, stats in region_stats_dict.items():
439
+ line = f"R{label}: patches={stats['num_patches']}, E={stats['expectation']:.3f}, visited={stats['patches_visited']}"
440
+ legend_lines.append(line)
441
+ legend_text = "\n".join(legend_lines)
442
+
443
+ # Add the legend text as a subtitle at the bottom of the figure
444
+ # Using fig.text to place the text at the center bottom with a background box
445
+ fig.text(0.5, 0.05, legend_text, ha='center', va='bottom', fontsize=10,
446
+ bbox=dict(facecolor='white', alpha=0.8, edgecolor='gray'))
447
+
448
+ # Adjust layout to reserve space for the legend text
449
+ plt.tight_layout()
450
+ plt.subplots_adjust(bottom=0.1) # Prevent overlap
451
+ # plt.tight_layout(rect=[0, 0.05, 1, 1])
452
+ # plt.show()
453
+
454
+
455
+ # gifs_folder = '/home/user/VLM-Search/inference/test_results/gifs'
456
+ if not os.path.exists(self.gifs_dir):
457
+ os.makedirs(self.gifs_dir)
458
+
459
+ plt.savefig(f'{self.gifs_dir}/kmeans_{episode_num}_{step_num}.png', dpi=150)
460
+ self.kmeans_frame_files.append(f'{self.gifs_dir}/kmeans_{episode_num}_{step_num}.png')
461
+ plt.close()
462
+
463
+
464
+ # ------ OTHER HELPER FUNCTIONS -------#
465
+ def get_label_id(self, patch_idx):
466
+ """
467
+ Given a flattened index (row-major order) within the 24×24 grid,
468
+ return the cluster label of the corresponding patch in the merged label map.
469
+
470
+ :param patch_idx: Flattened index (0 to H×W-1).
471
+ :return: The integer cluster label stored in `self.smoothed_labels_2d`
+ at that patch.
473
+ """
474
+ # Guard: check we have a merged label map
475
+ if self.smoothed_labels_2d is None:
476
+ raise ValueError("Merged markers do not exist. "
+ "Please call `fit_predict(...)` first.")
478
+
479
+ return self.smoothed_labels_2d.ravel()[patch_idx]
480
+
481
+
482
+ # ------ OTHER HELPER FUNCTIONS -------#
483
+ def get_probs(self, patch_idx, heatmap):
484
+ """
485
+ Given a flattened index (row-major order) within the 24×24 grid,
486
+ return the heatmap probability at the corresponding patch.
487
+
488
+ :param patch_idx: Flattened index (0 to H×W-1).
489
+ :return: The heatmap value at that patch.
491
+ """
492
+
493
+ return heatmap.ravel()[patch_idx]
494
+
495
+ ##############################
496
+ # Main functions
497
+ ##############################
498
+ def fit_predict(self, patch_embeds, map_shape):
499
+ """
500
+ Main entry point.
501
+
502
+ Parameters
503
+ ----------
504
+ patch_embeds : ndarray, shape (N, D)
505
+ The satellite patch embeddings to be clustered.
506
+ map_shape : tuple of (H, W)
507
+ The 2D layout of patches. Must satisfy H * W = N.
508
510
+
511
+ Returns
512
+ -------
513
+ smoothed_labels_2d : ndarray, shape (H, W)
+ The cluster assignments reshaped to the patch grid (smoothing is
+ currently disabled, so these are the raw KMeans labels).
515
+ """
516
+ # 1) Run combined silhouette & inertia
517
+ best_k, final_labels, silhouettes, inertias = self.combined_silhouette_inertia_clustering(
518
+ X=patch_embeds,
519
+ k_min=self.k_min,
520
+ k_max=self.k_max,
521
+ k_avg_max=self.k_avg_max,
522
+ silhouette_threshold=self.silhouette_threshold,
523
+ relative_threshold=self.relative_threshold,
524
+ random_state=self.random_state
525
+ )
526
+ self.final_k = best_k
527
+ self.final_labels_1d = final_labels.copy() # store raw labels
528
+
529
+ # 2) Reshape for display
530
+ H, W = map_shape
531
+ cluster_map = final_labels.reshape(H, W)
532
+
533
+ # 3) Apply smoothing
534
+ # cluster_map_smoothed = self.smooth_small_patches(
535
+ # labels_2d=cluster_map,
536
+ # min_patch_size=self.min_patch_size,
537
+ # n_iter=self.n_smooth_iter,
538
+ # ignore_label=self.ignore_label
539
+ # )
540
+ cluster_map_smoothed = cluster_map
541
+ self.smoothed_labels_2d = cluster_map_smoothed.copy()
542
+
543
+ # 5) Return the smoothed 2D labels
544
+ return self.smoothed_labels_2d
545
+
546
+
547
+ if __name__ == "__main__":
548
+
549
+ file_path="./expt/patch_embeds_99.npy"
550
+ patch_embeds = np.load(file_path)
551
+ map_shape = (int(np.sqrt(patch_embeds.shape[0])), int(np.sqrt(patch_embeds.shape[0])))
552
+
553
+ # 1) Load a 24x24 heatmap
554
+ file_path="./expt/seg_mask_step0_v1.npy"
555
+ heatmap = np.load(file_path)
556
+ heatmap = np.clip(heatmap,0,100)
557
+
558
+ # Normalize heatmap
559
+ num_targets = 19
560
+ scale = num_targets / (heatmap.sum())
561
+ heatmap *= scale
562
+
563
+ # Instantiate our clusterer
564
+ clusterer = CombinedSilhouetteInertiaClusterer(
565
+ k_min=1,
566
+ k_max=8,
567
+ k_avg_max=4,
568
+ silhouette_threshold=0.15,
569
+ relative_threshold=0.15,
570
+ random_state=0,
571
+ min_patch_size=5, # smoothing parameter
572
+ n_smooth_iter=2, # smoothing parameter
573
+ ignore_label=-1,
+ plot=True # enable the cluster/heatmap figure inside compute_region_statistics
574
+ )
575
+
576
+ # Fit & predict (returns the 2D cluster label map; plotting is controlled by the constructor's `plot` flag)
577
+ smoothed_labels_2d = clusterer.fit_predict(
578
+ patch_embeds=patch_embeds,
579
+ map_shape=map_shape,
580
+ )
581
+
582
+ # 2) Simulate a random walk
583
+ max_r, max_c=heatmap.shape[0]-1, heatmap.shape[1]-1
584
+ steps=100
585
+ current_pos=(0,0)
586
+ moves_8=[(-1,-1),(-1,0),(-1,1),
587
+ ( 0,-1), ( 0,1),
588
+ ( 1,-1),( 1,0), ( 1,1)]
589
+ path_coords=[current_pos]
590
+ # num_found_list=[0]
591
+ for _ in range(steps):
592
+ r,c=current_pos
593
+ dr,dc=random.choice(moves_8)
594
+ nr, nc=r+dr,c+dc
595
+ if 0<=nr<=max_r and 0<=nc<=max_c:
596
+ current_pos=(nr,nc)
597
+ path_coords.append(current_pos)
598
+ # 10% chance for new target
599
+ # num_found_list.append(random.random()<0.1)
600
+ # num_found_list.append(0)
601
+ # num_found_list[0]=1
602
+ # num_found_list[1]=2
603
+ # num_found_list[2]=3
604
+
605
+ # Flatten
606
+ width=heatmap.shape[1]
607
+ flattened_visits=[]
608
+ for (rr,cc) in path_coords:
609
+ idx=rr*width+cc
610
+ flattened_visits.append(idx)
611
+
612
+
613
+ # Input heatmap (for region statistics)
614
+ region_dict = clusterer.compute_region_statistics(smoothed_labels_2d, heatmap, flattened_visits)
615
+
616
+ # # save the smoother labels as .npy
617
+ # np.save('smoothed_labels_2d.npy', smoothed_labels_2d)
618
+
619
+ # print("Chosen k:", clusterer.final_k)
620
+ # print("Smoothed labels shape:", smoothed_labels_2d.shape)
Taxabind/Taxabind/SatBind/model.py ADDED
@@ -0,0 +1,226 @@
1
+ import os
2
+ import sys
3
+ # Ensure the correct directory is at the front of sys.path
4
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
5
+
6
+ import open_clip
7
+ import pytorch_lightning as pl
8
+ import torch
9
+ import torch.nn as nn
10
+ import numpy as np
11
+ from torch.utils.data import DataLoader
12
+ from transformers import CLIPVisionModelWithProjection
13
+ from dataset import SatNatDataset
14
+
15
+ from pytorch_lightning.loggers import WandbLogger
16
+ from pytorch_lightning.callbacks import ModelCheckpoint
17
+ from config import config
18
+
19
+ ######################################
20
+ class UpdateDatasetEpochCallback(pl.Callback):
21
+ def on_train_epoch_start(self, trainer, pl_module):
22
+ """Called at the start of each training epoch."""
23
+ print("!!!trainer.current_epoch: ", trainer.current_epoch)
24
+ current_epoch = trainer.current_epoch
25
+
26
+ # Access the dataset and set the epoch
27
+ train_loader = trainer.train_dataloader
28
+ # trainer.train_dataloader can be a list in some configs, handle that if needed
29
+ dataset = train_loader.dataset
30
+
31
+ # Now set the epoch:
32
+ if hasattr(dataset, 'set_epoch'):
33
+ dataset.set_epoch(current_epoch)
34
+ ######################################
35
+
36
+
37
+ def create_pairwise_mask(labels):
38
+ labels = labels.reshape(-1)
39
+ num_samples = len(labels)
40
+ pairwise_mask = torch.zeros(num_samples, num_samples).to(labels.device)
41
+
42
+ for i in range(num_samples):
43
+ pairwise_mask[i, :] = (labels == labels[i])
44
+
45
+ return pairwise_mask
46
+
47
+ def clip_loss(similarity: torch.Tensor, label) -> torch.Tensor:
48
+ overhead_img_loss = contrastive_loss(similarity, label)
49
+ ground_img_loss = contrastive_loss(similarity.t(), label.t())
50
+ return 0.5*torch.mean(torch.sum(overhead_img_loss, dim=-1)) + 0.5*torch.mean(torch.sum(ground_img_loss, dim=-1))
51
+
52
+ def contrastive_loss(logits: torch.Tensor, label) -> torch.Tensor:
53
+ gt = create_pairwise_mask(label)
54
+ return -gt*torch.log(logits.softmax(-1)+1e-6)
55
+
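+ # Minimal sketch of the supervised-contrastive target (hypothetical labels):
+ # with label = torch.tensor([0, 0, 1]), create_pairwise_mask returns
+ #   [[1, 1, 0],
+ #    [1, 1, 0],
+ #    [0, 0, 1]]
+ # so every ground/overhead pair sharing a category id counts as a positive,
+ # and clip_loss averages the masked cross-entropy in both directions.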
56
+ class SatBind(pl.LightningModule):
57
+ def __init__(self, train_dataset, val_dataset, **kwargs):
58
+ super().__init__()
59
+ self.train_dataset = train_dataset
60
+ self.val_dataset = val_dataset
61
+
62
+ #initialize bio CLIP with frozen weights
63
+ self.bio_model, *_ = open_clip.create_model_and_transforms('hf-hub:imageomics/bioclip')
64
+ if config.locked_tuning:
65
+ for param in self.bio_model.parameters():
66
+ param.requires_grad = False
67
+
68
+ #initialize CLIP with trainable weights
69
+ self.imo_encoder = CLIPVisionModelWithProjection.from_pretrained(config.sat_encoder).train()
70
+ for layer in self.imo_encoder.children():
71
+ if hasattr(layer, 'reset_parameters'):
72
+ layer.reset_parameters()
73
+
74
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
75
+ self.batch_size = kwargs.get('batch_size', config.batch_size)
76
+ self.lr = kwargs.get('lr', config.lr)
77
+
78
+ # # ADDED
79
+ clip_cfg = self.imo_encoder.config
80
+ self.visual_projection_custom = nn.Linear(clip_cfg.hidden_size, 512, bias=False) # clip_cfg.projection_dim)
81
+ # print("clip_cfg.hidden_size: ", clip_cfg.hidden_size)
82
+ # print("clip_cfg.projection_dim: ", clip_cfg.projection_dim)
83
+
84
+ # # Since CLIP's vision projection is not used (prevent pytorch lightning issues)
85
+ # for param in self.imo_encoder.visual_projection.parameters():
86
+ # param.requires_grad = False
87
+ # for param in self.imo_encoder.transformer.parameters(): # If the transformer is unused
88
+ # param.requires_grad = False
89
+
90
+ def forward(self, batch):
91
+ img, imo, label, patch_idx, *_ = batch
92
+ batch_size = img.shape[0]
93
+
94
+ #compute bioclip embeddings
95
+ img_embeds, *_ = self.bio_model(img) # (batch_size, proj_dim)
96
+
97
+ # ## ORIGINAL: compute overhead embeddings
98
+ # imo_embeds = self.imo_encoder(imo).image_embeds # (batch, proj_dim)
99
+ # print("[ORIG] imo_embeds.shape: ", imo_embeds.shape)
100
+ # print("[ORIG] torch.min(imo_embeds): ", torch.min(imo_embeds))
101
+ # print("[ORIG] torch.max(imo_embeds): ", torch.max(imo_embeds))
102
+
103
+ # NEW:
104
+ imo_embeds = self.imo_encoder(imo).last_hidden_state # (batch, Patches, hidden_dim)
105
+ imo_embeds = imo_embeds[torch.arange(batch_size), patch_idx] # (batch, hidden_dim)
106
+ # imo_embeds = self.imo_encoder.vision_model.post_layernorm(imo_embeds) # (batch, hidden_dim)
107
+ # imo_embeds = self.imo_encoder.visual_projection(imo_embeds) # (batch_size, proj_dim)
108
+ imo_embeds = self.visual_projection_custom(imo_embeds) # (batch_size, proj_dim)
109
+
110
+ return img_embeds, imo_embeds, label
111
+
112
+
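+ # Shape sketch for the patch-level selection above (sizes depend on
+ # config.sat_encoder; the values below are assumptions for a standard CLIP ViT):
+ #   last_hidden_state: (B, 1 + P, hidden)  # [CLS] token followed by P patches
+ #   patch_idx:         (B,)                # already offset by +1 for [CLS]
+ #   gathered:          (B, hidden)         # one patch embedding per sample
+ #   projected:         (B, 512)            # via visual_projection_custom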
113
+ def shared_step(self, batch, return_sim_matrix=False):
114
+
115
+ img_embeds, imo_embeds, label, *_ = self(batch)
116
+ # normalize embeddings (img_embeds is already L2-normalized by BioCLIP, so it is left unchanged)
119
+ imo_embeds = torch.nn.functional.normalize(imo_embeds, dim=-1)
120
+ # print("[NORM] torch.min(imo_embeds): ", torch.min(imo_embeds))
121
+ # print("[NORM] torch.max(imo_embeds): ", torch.max(imo_embeds))
122
+
123
+ # exponentiate the log of the temperature
124
+ logit_scale = self.logit_scale.exp()
125
+
126
+ #compute similarity
127
+ img_to_imo_sim = img_embeds @ imo_embeds.t() * logit_scale
128
+
129
+ if return_sim_matrix:
130
+ img_to_imo_sim_copy = img_to_imo_sim.clone().detach()
131
+
132
+ loss = clip_loss(img_to_imo_sim, label)
133
+
134
+ if return_sim_matrix:
135
+ return loss, img_to_imo_sim_copy
136
+ else:
137
+ return loss
138
+
139
+
140
+ def training_step(self, batch, batch_idx):
141
+ loss = self.shared_step(batch)
142
+ self.log('train_loss', loss, sync_dist=True, prog_bar=True, on_epoch=True, batch_size=self.batch_size)
143
+ self.log('temperature', self.logit_scale.data, prog_bar=True, on_epoch=True, batch_size=self.batch_size)
144
+ return loss
145
+
146
+ def validation_step(self, batch, batch_idx):
147
+ loss = self.shared_step(batch)
148
+ self.log('val_loss', loss, sync_dist=True, prog_bar=True, on_epoch=True, batch_size=self.batch_size)
149
+ return loss
150
+
151
+ def train_dataloader(self):
152
+ return DataLoader(self.train_dataset,
153
+ batch_size=self.batch_size,
154
+ num_workers=config.num_workers,
155
+ shuffle=True, # True
156
+ persistent_workers=False)
157
+
158
+ def val_dataloader(self):
159
+ return DataLoader(self.val_dataset,
160
+ batch_size=self.batch_size,
161
+ num_workers=config.num_workers,
162
+ shuffle=False,
163
+ persistent_workers=False)
164
+
165
+ def configure_optimizers(self):
166
+ params = self.parameters()
167
+ self.optim = torch.optim.AdamW(params,
168
+ lr=self.lr,
169
+ betas=(0.9,0.98),
170
+ eps=1e-6
171
+ )
172
+ self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
173
+ optimizer=self.optim,
174
+ T_0=20,
175
+ eta_min=1e-6
176
+ )
177
+ return [self.optim], [self.scheduler]
178
+
179
+
180
+ if __name__ == '__main__':
181
+ img_dir = config.img_dir
182
+ imo_dir = config.imo_dir
183
+ imo_dir_val = config.imo_dir_val
184
+ train_json_path = config.train_json_path
185
+ val_json_path = config.val_json_path
186
+ # filtered_train_json_path = config.filtered_train_json_path
187
+ # filtered_val_json_path = config.filtered_val_json_path
188
+ sat_to_img_ids_train_json_path = config.sat_to_img_ids_train_json_path
189
+ sat_to_img_ids_val_json_path = config.sat_to_img_ids_val_json_path
190
+ patch_size = config.patch_size
191
+
192
+ #define dataset
193
+ train_dataset = SatNatDataset(img_dir, imo_dir, train_json_path, sat_to_img_ids_train_json_path, patch_size)
194
+ val_dataset = SatNatDataset(img_dir, imo_dir_val, val_json_path, sat_to_img_ids_val_json_path, patch_size, mode='val')
195
+
196
+ #define model
197
+ model = SatBind(train_dataset=train_dataset, val_dataset=val_dataset)
198
+ torch.cuda.empty_cache()
199
+
200
+ checkpoint = ModelCheckpoint(
201
+ monitor='val_loss',
202
+ dirpath=config.save_dir,
203
+ filename=config.filename,
204
+ mode='min',
205
+ save_top_k=1, # save only the best checkpoint (according to val_loss)
206
+ save_last=True # also save the last checkpoint every epoch
207
+ # last_filename=config.filename + "-LAST"
208
+ )
209
+ checkpoint.CHECKPOINT_NAME_LAST = config.filename + "-LAST" # Rename last checkpoint name
210
+
211
+ trainer = pl.Trainer(
212
+ accelerator='gpu',
213
+ strategy='ddp_find_unused_parameters_true', # ddp (orig), ddp_find_unused_parameters_true (suppress pl issues with unused trainable params)
214
+ devices=config.devices,
215
+ max_epochs=config.max_epochs,
216
+ num_nodes=1,
217
+ callbacks=[checkpoint, UpdateDatasetEpochCallback()],
218
+ accumulate_grad_batches=config.accumulate_grad_batches,
219
+ log_every_n_steps=1,
220
+ val_check_interval=config.val_check_interval,
221
+ )
222
+
223
+ if config.resume_from_checkpoint:
224
+ trainer.fit(model, ckpt_path=f"{config.save_dir}/{config.resume_checkpoint_name}.ckpt")
225
+ else:
226
+ trainer.fit(model)
Taxabind/Taxabind/SatBind/watershed_segmentation.py ADDED
@@ -0,0 +1,1094 @@
1
+ import os
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ import cv2
5
+ import random
6
+ import math
8
+ from collections import Counter
9
+ from scipy import ndimage
10
+
11
+ # import matplotlib
12
+ # matplotlib.use("Agg") # <-- key line to avoid tkinter dependency
13
+
14
+ class WatershedBinomial:
15
+ def __init__(
16
+ self,
17
+ N=5.0, # total expected # of targets across entire map
18
+ blur_kernel=(5, 5),
19
+ sure_fg_factor=0.7,
20
+ dilation_kernel_size=(3, 3),
21
+ dilation_iters=3,
22
+ region_size_thresh=10,
23
+ plot_img=False,
24
+ gifs_dir="./gifs",
25
+ colormap=plt.cm.tab10,
26
+ title="Watershed + Binomial Steps"
27
+ ):
28
+ """
29
+ A binomial viewpoint:
30
+ 1) We run watershed on a 24x24 heatmap.
31
+ 2) We scale each region's probability p_i so sum of expected targets across
32
+ all regions = N. Then expected # in region i is E_i = size_i * p_i.
33
+ 3) We track how many targets each region has found so far (#found).
34
+ 4) The average steps to find the next target is size_i/E_i = 1/p_i.
35
+ 5) We store these as class variables so they can be accessed later:
36
+ - region_prob[lab]
37
+ - region_visited_str[lab]
38
+ - region_expected[lab]
39
+ - region_found[lab]
40
+ - region_steps_next[lab]
41
+ 6) The final legend includes all this info for each region.
42
+ """
43
+
44
+ self.N = N
45
+ self.blur_kernel = blur_kernel
46
+ self.sure_fg_factor = sure_fg_factor
47
+ self.dilation_kernel_size = dilation_kernel_size
48
+ self.dilation_iters = dilation_iters
49
+ self.region_size_thresh = region_size_thresh
50
+ self.plot_img = plot_img
51
+ self.gifs_dir = gifs_dir
52
+ self.colormap = colormap
53
+ self.title = title
54
+
55
+ self.raw_heatmap = None
56
+ self.flattened_visits = None
57
+ self.num_found_list = None
58
+
59
+ # --------------------------
60
+ # Intermediate images
61
+ # --------------------------
62
+ self.input_heatmap_24x24 = None
63
+ self.padded_heatmap_26x26 = None
64
+ self.blurred_heatmap_26x26 = None
65
+ self.binary_mask_26x26 = None
66
+ self.dist_transform_26x26 = None
67
+ self.sure_fg_mask_26x26 = None
68
+ self.watershed_markers_26x26 = None
69
+ self.color_watershed_raw_26x26 = None
70
+
71
+ # --------------------------
72
+ # Final label map + overlays
73
+ # --------------------------
74
+ self.final_markers_24x24 = None
75
+ self.final_output_no_vis_24x24 = None
76
+ self.final_output_with_vis_24x24 = None
77
+
78
+ # --------------------------
79
+ # Class variables of interest
80
+ # --------------------------
81
+ # region_prob[lab] -> scaled probability p_i
82
+ # region_visited_str[lab]-> visited fraction as "visited_count/size_i"
83
+ # region_expected[lab] -> E_i = size_i * p_i
84
+ # region_found[lab] -> number of targets found so far
85
+ # region_steps_next[lab] -> size_i/E_i (1/p_i)
86
+ self.region_prob = {}
87
+ self.visited_num = {}
88
+ self.region_visited_str = {}
89
+ self.region_centroids = {}
90
+ self.region_expected = {}
91
+ self.region_found = {}
92
+ self.region_steps_next = {}
93
+
94
+ # --------------------------
95
+ # For MERGED segmentation with k-means boundaries
96
+ # We'll store them in separate variables
97
+ # --------------------------
98
+ self.merged_markers_24x24 = None
99
+ self.merged_output_no_vis = None
100
+ self.merged_output_with_vis = None
101
+
102
+ self.merged_region_prob = {}
103
+ self.merged_visited_num = {}
104
+ self.merged_region_visited_str = {}
105
+ self.merged_region_expected = {}
106
+ self.merged_region_found = {}
107
+ self.merged_region_steps_next = {}
108
+ self.merged_region_centroids = {}
109
+
110
+ # OTHERS
111
+ self.segmentor_frame_files = []
112
+ self.kmeans_merged_frame_files = []
113
+
114
+ # Persistent label-color mapping (Preassign colors for labels)
115
+ self.label_to_color = {}
116
+ for i in range(1, 101):
117
+ color = np.array(self.colormap(i % 10)[:3]) * 255
118
+ self.label_to_color[i] = color.astype(np.uint8)
119
+
120
+ # --------------------------
121
+ # Public Entry
122
+ # --------------------------
123
+
124
+ def run(self, raw_heatmap, flattened_visits, num_found_list, step=0):
125
+ """
126
+ 1) Watershed pipeline
127
+ 2) Scale probabilities => sum_i #expected = N
128
+ 3) Build final overlays & store region stats
129
+ 4) Plot subplots
130
+ """
131
+ self.raw_heatmap = raw_heatmap
132
+ self.flattened_visits = flattened_visits
133
+ self.num_found_list = num_found_list
134
+ self.input_heatmap_24x24 = self.raw_heatmap.copy()
135
+
136
+ # Step A) Watershed + merges
137
+ self._watershed_on_padded()
138
+ self.final_markers_24x24 = self._postprocess_cropped_markers()
139
+
140
+ # Step B) Build final binomial-based stats
141
+ self._build_final_outputs()
142
+
143
+ # Step C) Plot
144
+ if self.plot_img:
145
+ self._plot_all_steps(step)
146
+ return self.final_markers_24x24
147
+
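+ # Hedged usage sketch (shapes and values are illustrative only):
+ #   heatmap = np.random.rand(24, 24)           # per-patch probability map
+ #   visits = [0, 1, 25]                        # flattened row-major patch indices
+ #   found = [0, 0, 1]                          # targets found at each visit
+ #   ws = WatershedBinomial(N=5.0, plot_img=False)
+ #   markers = ws.run(heatmap, visits, found)   # (24, 24) region label map
+ #   ws.region_expected, ws.region_steps_next   # per-region binomial statistics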
148
+ # --------------------------
149
+ # Watershed pipeline
150
+ # --------------------------
151
+
152
+ def _pad_image(self, image, pad_size=1):
153
+ return np.pad(image, pad_size, mode='edge')
154
+
155
+ def _crop_image(self, image, pad_size=1):
156
+ return image[pad_size:-pad_size, pad_size:-pad_size]
157
+
158
+ def _watershed_on_padded(self):
159
+ padded = self._pad_image(self.raw_heatmap, pad_size=1)
160
+ self.padded_heatmap_26x26 = padded
161
+
162
+ # Normalize
163
+ norm_padded = (
164
+ (padded - padded.min()) / (padded.max() - padded.min() + 1e-9)
165
+ * 255
166
+ ).astype(np.uint8)
167
+
168
+ # (2) Blur
169
+ if self.blur_kernel != (0,0):
170
+ self.blurred_heatmap_26x26 = cv2.GaussianBlur(norm_padded, self.blur_kernel, 0)
171
+ else:
172
+ self.blurred_heatmap_26x26 = norm_padded.copy()
173
+
174
+ # (3) Binary
175
+ _, bin_mask = cv2.threshold(
176
+ self.blurred_heatmap_26x26,
177
+ 0, 255,
178
+ cv2.THRESH_BINARY + cv2.THRESH_OTSU
179
+ )
180
+ self.binary_mask_26x26 = bin_mask
181
+
182
+ # (4) Distance transform
183
+ dist_transform = cv2.distanceTransform(bin_mask, cv2.DIST_L2, 5)
184
+ self.dist_transform_26x26 = dist_transform
185
+
186
+ # (5) Sure FG
187
+ _, sure_fg = cv2.threshold(
188
+ dist_transform,
189
+ self.sure_fg_factor * dist_transform.max(),
190
+ 255, 0
191
+ )
192
+ sure_fg = np.uint8(sure_fg)
193
+ self.sure_fg_mask_26x26 = sure_fg
194
+
195
+ # sure BG
196
+ kernel = np.ones(self.dilation_kernel_size, np.uint8)
197
+ sure_bg = cv2.dilate(bin_mask, kernel, iterations=self.dilation_iters)
198
+ unknown = cv2.subtract(sure_bg, sure_fg)
199
+
200
+ # Markers
201
+ _, markers = cv2.connectedComponents(sure_fg)
202
+ markers += 1
203
+ markers[unknown==255] = 0
204
+
205
+ # (6) Watershed
206
+ color_img = cv2.applyColorMap(self.blurred_heatmap_26x26, cv2.COLORMAP_VIRIDIS)
207
+ markers = cv2.watershed(color_img, markers)
208
+ self.watershed_markers_26x26 = markers
209
+
210
+ self.color_watershed_raw_26x26 = self._color_label_map(markers)
211
+
212
+ def _postprocess_cropped_markers(self):
213
+ cropped = self._crop_image(self.watershed_markers_26x26, 1)
214
+ cropped = self._fix_boundary_labels(cropped)
215
+ cropped = self._separate_subregions(cropped)
216
+ cropped = self._merge_small_regions(cropped, self.region_size_thresh)
217
+ return cropped
218
+
219
+ def _fix_boundary_labels(self, markers):
220
+ b_idx = np.argwhere(markers==-1)
221
+ neighbors_8 = [
222
+ (-1,-1), (-1,0), (-1,1),
223
+ ( 0,-1), ( 0,1),
224
+ ( 1,-1), ( 1,0), ( 1,1)
225
+ ]
226
+ for (r,c) in b_idx:
227
+ neigh_labs=[]
228
+ for dr,dc in neighbors_8:
229
+ rr,cc=r+dr,c+dc
230
+ if 0<=rr<markers.shape[0] and 0<=cc<markers.shape[1]:
231
+ lb=markers[rr,cc]
232
+ if lb>0:
233
+ neigh_labs.append(lb)
234
+ if neigh_labs:
235
+ markers[r,c]=np.bincount(neigh_labs).argmax()
236
+ else:
237
+ markers[r,c]=0
238
+ return markers
239
+
240
+ def _separate_subregions(self, markers):
241
+ refined = np.zeros_like(markers, dtype=np.int32)
242
+ refined[markers==-1] = -1
243
+ refined[markers==0] = 0
244
+
245
+ next_label=1
246
+ labs = np.unique(markers)
247
+ for lb in labs:
248
+ if lb<=0:
249
+ continue
250
+ region_mask=(markers==lb).astype(np.uint8)
251
+ n_sub, sub_markers = cv2.connectedComponents(region_mask)
252
+ for sub_id in range(1,n_sub):
253
+ subregion_mask=(sub_markers==sub_id)
254
+ refined[subregion_mask]=next_label
255
+ next_label+=1
256
+ return refined
257
+
258
+
259
+ def _merge_small_regions(self, markers, min_patch_size, n_iter=1, ignore_label=-99):
260
+ """
261
+ Combined approach using:
262
+ (A) Pixel-wise smoothing of small patches (iterative),
263
+ but now using 8-connected neighbors.
264
+ (B) After that, remove any connected component whose
265
+ size < self.region_size_thresh by merging it into
266
+ the majority label among its 8-connected boundary.
267
+
268
+ 1) We do not change the pixel-wise smoothing logic, except
269
+ that we consider 8 neighbors instead of 4.
270
+ 2) Then we do a region-based check for all subregions whose
271
+ size < region_size_thresh. Each such region is assigned
272
+ to the 8-neighbor majority among that region's boundary
273
+ pixels. If a tie occurs, we pick the largest label ID
274
+ or keep the original? (You can define your tie logic
275
+ as you wish, shown here picking the top or random.)
276
+ """
277
+
278
+ H, W = markers.shape
279
+
280
+ # -----------------------------------------------------------
281
+ # (A) Pixel-wise smoothing with 8-connected neighbors
282
+ # for small sub-patches of each label
283
+ # -----------------------------------------------------------
284
+
285
+ markers = ndimage.median_filter(markers, size=3)
286
+
287
+ # -----------------------------------------------------------
288
+ # (B) Now remove entire subregions < min_patch_size by
289
+ # reassigning them to the 8-neighbor majority of boundary
290
+ # -----------------------------------------------------------
291
+
292
+ # smoothed = markers.copy()
293
+ structure_8 = np.array([[1,1,1],
294
+ [1,1,1],
295
+ [1,1,1]], dtype=np.uint8)
296
+
297
+ # identify connected components of each label
298
+ # step: we'll do a multi-pass until no small region is left
299
+ changed = True
300
+ while changed:
301
+ changed = False
302
+ label_map = markers
303
+ labels = np.unique(label_map[label_map>0])
304
+ for label in labels:
305
+ # If label patches is > threshold
306
+ if label < 0 or np.sum(label_map == label) >= min_patch_size:
307
+ continue
308
+
309
+ # find connected components of 'label'
310
+ mask = (label_map == label)
311
+ labeled_cc, num_cc = ndimage.label(mask, structure=structure_8)
312
+ if num_cc < 1:
313
+ continue
314
+
315
+ for cc_id in range(1, num_cc+1):
316
+ cc_mask = (labeled_cc == cc_id)
317
+ # size_cc = np.sum(cc_mask)
318
+ # if 0<size_cc<min_patch_size:
319
+ # compute the 8-neighbor majority from boundary
320
+ boundary_pixels = []
321
+ coords = np.argwhere(cc_mask)
322
+ for (r,c) in coords:
323
+ # check if (r,c) is on boundary => or we can check neighbors
324
+ for dr in (-1,0,1):
325
+ for dc in (-1,0,1):
326
+ if dr==0 and dc==0:
327
+ continue
328
+ rr,cc2 = r+dr, c+dc
329
+ if 0<=rr<H and 0<=cc2<W:
330
+ # If neighbor label != label, it's a boundary
331
+ neighbor_label = label_map[rr, cc2]
332
+ if neighbor_label != label and neighbor_label > 0:
333
+ boundary_pixels.append(neighbor_label)
334
+
335
+ if boundary_pixels:
336
+ counter = Counter(boundary_pixels)
337
+ # remove any '0' or negative labels
338
+ # for k in list(counter.keys()):
339
+ # if k<=0:
340
+ # del counter[k]
341
+ if len(counter)>0:
342
+ # pick the label with the largest count
343
+ (new_lab, _cnt) = counter.most_common(1)[0]
344
+ label_map[cc_mask] = new_lab
345
+ changed = True
346
+ else:
347
+ # no valid neighbor => keep the label, or reassign to 0
348
+ label_map[cc_mask] = 0
349
+ changed = True
350
+
351
+ return label_map
352
+
353
+ # def largest_component_mask(self, label_map, label):
354
+ # """Find largest connected component for 'label' using 8-neighbor structure."""
355
+ # mask = (label_map == label)
356
+ # structure_8 = np.array([[1,1,1],
357
+ # [1,1,1],
358
+ # [1,1,1]], dtype=np.uint8)
359
+ # labeled, num = ndimage.label(mask, structure=structure_8)
360
+ # if num < 1:
361
+ # return np.zeros_like(mask, dtype=bool)
362
+ # # measure region sizes
363
+ # sizes = ndimage.sum(mask, labeled, range(1,num+1))
364
+ # largest_id = (np.argmax(sizes) + 1)
365
+ # return (labeled == largest_id)
366
+
367
+ # NOT IN USE
368
+ def _merge_small_regions_smoothing(self, markers):
369
+ label_sizes={}
370
+ labs=np.unique(markers)
371
+ for lb in labs:
372
+ if lb>0:
373
+ label_sizes[lb]=np.sum(markers==lb)
374
+ sorted_labels=sorted(label_sizes, key=lambda x: label_sizes[x])
375
+
376
+ neighbors_8=[
377
+ (-1,-1), (-1,0), (-1,1),
378
+ ( 0,-1), ( 0,1),
379
+ ( 1,-1), ( 1,0), ( 1,1)
380
+ ]
381
+ for lb in sorted_labels:
382
+ size=label_sizes[lb]
383
+ if 0<size<self.region_size_thresh:
384
+ coords=np.argwhere(markers==lb)
385
+ nbr_set=set()
386
+ for (r,c) in coords:
387
+ for dr,dc in neighbors_8:
388
+ rr,cc=r+dr,c+dc
389
+ if 0<=rr<markers.shape[0] and 0<=cc<markers.shape[1]:
390
+ nl=markers[rr,cc]
391
+ if nl>0 and nl!=lb:
392
+ nbr_set.add(nl)
393
+ if not nbr_set:
394
+ continue
395
+ max_lb=None
396
+ max_sz=-1
397
+ for nlb in nbr_set:
398
+ s_nlb=label_sizes.get(nlb,0)
399
+ if s_nlb>max_sz:
400
+ max_sz=s_nlb
401
+ max_lb=nlb
402
+ markers[markers==lb]=max_lb
403
+ label_sizes[max_lb]+=size
404
+ label_sizes[lb]=0
405
+ return markers
406
+
407
+ # --------------------------
408
+ # Build binomial-based stats
409
+ # --------------------------
410
+
411
+ def _build_final_outputs(self):
412
+ """
413
+ 1) color-labeled final overlay
414
+ 2) visited overlay
415
+ 3) scale region probabilities => sum of E_i = N
416
+ 4) store visited fraction, expected #targets, #found, steps next
417
+ """
418
+ markers=self.final_markers_24x24
419
+ valid_labels=np.unique(markers[markers>0])
420
+
421
+ # color-labeled overlay
422
+ color_map=self._color_label_map(markers)
423
+ self.final_output_no_vis_24x24=self._overlay_heatmap_and_labels(self.raw_heatmap, color_map)
424
+
425
+ # visited
426
+ visited_mask=np.zeros(markers.shape,dtype=bool)
427
+ h,w=markers.shape
428
+ for idx in self.flattened_visits:
429
+ rr=idx//w
430
+ cc=idx%w
431
+ if 0<=rr<h and 0<=cc<w:
432
+ visited_mask[rr,cc]=True
433
+
434
+ with_vis=self.final_output_no_vis_24x24.copy()
435
+ red_overlay=np.zeros_like(with_vis)
436
+ alpha=0.3
437
+ for idx in self.flattened_visits:
438
+ rr=idx//w
439
+ cc=idx%w
440
+ if 0<=rr<h and 0<=cc<w:
441
+ red_overlay[rr,cc]=[255,0,0]
442
+ self.final_output_with_vis_24x24=cv2.addWeighted(with_vis,1.0,red_overlay,alpha,0.0)
443
+
444
+ # gather region sums
445
+ label_sizes={}
446
+ label_sums={}
447
+ total_mass=0.0
448
+ for lb in valid_labels:
449
+ mask=(markers==lb)
450
+ size_i=np.sum(mask)
451
+ label_sizes[lb]=size_i
452
+ s=np.sum(self.raw_heatmap[mask])
453
+ label_sums[lb]=s
454
+ total_mass+=s
455
+
456
+ # scale factor
457
+ scale=self.N/(total_mass+1e-9)
458
+
459
+ # how many visited patches in each region
460
+ region_visited={}
461
+ for lb in valid_labels:
462
+ region_visited[lb]=np.sum((markers==lb) & visited_mask)
463
+
464
+ # how many targets found in each region
465
+ region_found_count={lb:0 for lb in valid_labels}
466
+ for idx, num_found in zip(self.flattened_visits,self.num_found_list):
467
+ rr=idx//w
468
+ cc=idx%w
469
+ lb=markers[rr,cc]
470
+ if lb>0 and num_found>0:
471
+ region_found_count[lb]+=num_found
472
+
473
+ # store stats
474
+ for lb in valid_labels:
475
+ size_i=label_sizes[lb]
476
+ raw_sum=label_sums[lb]
477
+ if size_i>0:
478
+ p_i=scale*(raw_sum/size_i)
479
+ else:
480
+ p_i=0.0
481
+ self.region_prob[lb]=p_i
482
+
483
+ visited_num=region_visited[lb]
484
+ # visited fraction as a string, e.g. "23/100"
485
+ self.visited_num[lb]=visited_num
486
+ self.region_visited_str[lb]=f"{visited_num}/{size_i}"
487
+
488
+ E_i=size_i*p_i
489
+ self.region_expected[lb]=E_i
490
+
491
+ fcount=region_found_count[lb]
492
+ self.region_found[lb]=fcount
493
+
494
+ self.region_centroids[lb] = self._find_centroid_in_region(markers, lb)
495
+
496
+ # steps for the (found+1)-th target => size_i/E_i = 1/p_i
497
+ if E_i>0:
498
+ next_steps=(fcount+1)*size_i/E_i
499
+ else:
500
+ next_steps=float('inf')
501
+ self.region_steps_next[lb]=next_steps
502
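A small worked example of the statistics above, with made-up numbers (two regions, N = 10 expected targets over the whole map):

    N = 10.0
    size = {1: 100, 2: 50}             # patches per region
    raw_sum = {1: 300.0, 2: 100.0}     # heatmap mass per region
    scale = N / sum(raw_sum.values())  # 10 / 400 = 0.025

    p = {lb: scale * raw_sum[lb] / size[lb] for lb in size}   # {1: 0.075, 2: 0.05}
    E = {lb: size[lb] * p[lb] for lb in size}                 # {1: 7.5, 2: 2.5}, sums to N
    found = {1: 2, 2: 0}
    steps_next = {lb: (found[lb] + 1) * size[lb] / E[lb] for lb in size}
    # {1: 40.0, 2: 20.0}: expected number of patches to visit for the (found+1)-th target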
+
503
+
504
+ def _find_centroid_in_region(self, markers, label_id):
505
+ coords = np.argwhere(markers == label_id)
506
+ if coords.size == 0:
507
+ return None
508
+ r_mean = np.mean(coords[:, 0])
509
+ c_mean = np.mean(coords[:, 1])
510
+ dists = (coords[:, 0] - r_mean)**2 + (coords[:, 1] - c_mean)**2
511
+ idx = np.argmin(dists)
512
+ return (coords[idx][0], coords[idx][1])
513
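A hedged toy example of why the centroid is snapped to an in-region pixel: for a ring-shaped region the arithmetic mean lands on the hole, which is not part of the region, so the nearest in-region pixel is returned instead.

    import numpy as np

    coords = np.array([[0, 0], [0, 1], [0, 2],
                       [1, 0],         [1, 2],
                       [2, 0], [2, 1], [2, 2]])         # ring, hole at (1, 1)
    r_mean, c_mean = coords.mean(axis=0)                # (1.0, 1.0): not a region pixel
    dists = (coords[:, 0] - r_mean) ** 2 + (coords[:, 1] - c_mean) ** 2
    print(coords[np.argmin(dists)].tolist())            # [0, 1], a pixel inside the region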
+
514
+
515
+ # --------------------------
516
+ # Visualization
517
+ # --------------------------
518
+
519
+ def _color_label_map(self, markers):
520
+ out_img=np.zeros((*markers.shape,3),dtype=np.uint8)
521
+ labs=np.unique(markers[markers>0])
522
+ # col_array=self.colormap(np.linspace(0,1,len(labs)))
523
+
524
+ for r in range(markers.shape[0]):
525
+ for c in range(markers.shape[1]):
526
+ lb = markers[r, c]
527
+ if lb in self.label_to_color:
528
+ out_img[r, c] = self.label_to_color[lb]
529
+
530
+ # OLD: Color keeps jumping
531
+ # lab2col={}
532
+ # for i,lb in enumerate(labs):
533
+ # lab2col[lb]=(col_array[i][:3]*255).astype(np.uint8)
534
+
535
+ # for r in range(markers.shape[0]):
536
+ # for c in range(markers.shape[1]):
537
+ # lb=markers[r,c]
538
+ # if lb in lab2col:
539
+ # out_img[r,c]=lab2col[lb]
540
+ return out_img
541
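As an aside, the per-pixel loop above can also be written as a single lookup-table indexing (a sketch, assuming labels are small non-negative integers and that `label_to_color` maps label to an RGB triple as in the class):

    import numpy as np

    def color_label_map_vectorized(markers, label_to_color):
        # Build a (max_label + 1, 3) color table, then index it with the marker array.
        lut = np.zeros((markers.max() + 1, 3), dtype=np.uint8)
        for lb, col in label_to_color.items():
            if 0 <= lb <= markers.max():
                lut[lb] = col
        return lut[markers]   # (H, W, 3) image; label 0 stays black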
+
542
+ def _overlay_heatmap_and_labels(self, heatmap_24x24, color_map):
543
+ heatmap_rgb=plt.cm.gray((heatmap_24x24/100.0))[:,:,:3]
544
+ heatmap_rgb=(heatmap_rgb*255).astype(np.uint8)
545
+ alpha=0.5
546
+ blended=cv2.addWeighted(heatmap_rgb,1.0,color_map,alpha,0.0)
547
+ return blended
548
+
549
+ def _plot_all_steps(self, step):
550
+ fig, axs = plt.subplots(2,4, figsize=(20,10))
551
+ fig.suptitle(self.title, fontsize=16)
552
+ fig.text(0.5, 0.05, self._get_subtitles(), ha="center", va="center", fontsize=12)
553
+ plt.tight_layout()
554
+ plt.subplots_adjust(bottom=0.1) # Prevent overlap
555
+
556
+ # (1) Input
557
+ axs[0,0].imshow(self.input_heatmap_24x24, cmap='viridis')
558
+ axs[0,0].set_title("1) Input Heatmap")
559
+ axs[0,0].axis("off")
560
+
561
+ # (2) After Blur
562
+ if self.blurred_heatmap_26x26 is not None:
563
+ axs[0,1].imshow(self.blurred_heatmap_26x26, cmap='gray')
564
+ else:
565
+ axs[0,1].text(0.2,0.5,"No blur", fontsize=12)
566
+ axs[0,1].set_title("2) After Gaussian Blur")
567
+ axs[0,1].axis("off")
568
+
569
+ # (3) Binary Mask
570
+ if self.binary_mask_26x26 is not None:
571
+ axs[0,2].imshow(self.binary_mask_26x26, cmap='gray')
572
+ else:
573
+ axs[0,2].text(0.2,0.5,"No binary", fontsize=12)
574
+ axs[0,2].set_title("3) Binary Mask")
575
+ axs[0,2].axis("off")
576
+
577
+ # (4) Distance Transform
578
+ if self.dist_transform_26x26 is not None:
579
+ axs[0,3].imshow(self.dist_transform_26x26, cmap='jet')
580
+ else:
581
+ axs[0,3].text(0.2,0.5,"No dist transform", fontsize=12)
582
+ axs[0,3].set_title("4) Distance Transform")
583
+ axs[0,3].axis("off")
584
+
585
+ # (5) Sure Foreground
586
+ if self.sure_fg_mask_26x26 is not None:
587
+ axs[1,0].imshow(self.sure_fg_mask_26x26, cmap='gray')
588
+ else:
589
+ axs[1,0].text(0.2,0.5,"No sure-FG", fontsize=12)
590
+ axs[1,0].set_title("5) Sure Foreground Mask")
591
+ axs[1,0].axis("off")
592
+
593
+ # (6) Raw Watershed
594
+ if self.color_watershed_raw_26x26 is not None:
595
+ axs[1,1].imshow(self.color_watershed_raw_26x26)
596
+ else:
597
+ axs[1,1].text(0.2,0.5,"No watershed", fontsize=12)
598
+ axs[1,1].set_title("6) Raw Watershed Output")
599
+ axs[1,1].axis("off")
600
+
601
+ # (7) Final (No Vis)
602
+ if self.final_output_no_vis_24x24 is not None:
603
+ axs[1,2].imshow(self.final_output_no_vis_24x24)
604
+ self._plot_centroids(axs[1,2])
605
+ self._add_legend(axs[1,2])
606
+ else:
607
+ axs[1,2].text(0.2,0.5,"No final", fontsize=12)
608
+ axs[1,2].set_title("7) Final Output (No Vis)")
609
+ axs[1,2].axis("off")
610
+
611
+ # (8) Final (With Vis)
612
+ if self.final_output_no_vis_24x24 is not None:
613
+ axs[1,3].imshow(self.final_output_no_vis_24x24) # final_output_with_vis_24x24
614
+ # Red line for path
615
+ path_rows, path_cols = [], []
616
+ targets = []
617
+ for i, idx in enumerate(self.flattened_visits):
618
+ rr = idx // self.final_output_no_vis_24x24.shape[1]
619
+ cc = idx % self.final_output_no_vis_24x24.shape[1]
620
+ path_rows.append(rr)
621
+ path_cols.append(cc)
622
+ if self.num_found_list[i] > 0:
623
+ targets.append((rr, cc))
624
+ axs[1,3].plot(path_cols, path_rows, c="r", linewidth=2)
625
+ axs[1,3].plot(path_cols[-1], path_rows[-1], markersize=12, zorder=99, marker="^", ls="-", c="r", mec="black")
626
+ axs[1,3].plot(path_cols[0], path_rows[0], 'co', c="r", markersize=8, zorder=5)
627
+ for target in targets:
628
+ axs[1,3].plot(target[1], target[0], color='g', marker='x', linestyle='-', mec="black", markersize=12, markeredgewidth=4, zorder=999)
629
+ self._plot_centroids(axs[1,3])
630
+ self._add_legend(axs[1,3])
631
+ else:
632
+ axs[1,3].text(0.2,0.5,"No visited out", fontsize=12)
633
+ axs[1,3].set_title("8) Final Output (With Vis)")
634
+ axs[1,3].axis("off")
635
+
636
+ # plt.show()
637
+
638
+ if not os.path.exists(self.gifs_dir):
639
+ os.makedirs(self.gifs_dir)
640
+
641
+ plt.savefig(f'{self.gifs_dir}/watershed_{step}.png', dpi=150)
642
+ self.segmentor_frame_files.append(f'{self.gifs_dir}/watershed_{step}.png')
643
+ plt.close()
644
+
645
+
646
+ # def _plot_centroids(self, ax):
647
+ # markers=self.final_markers_24x24
648
+ # labs=np.unique(markers[markers>0])
649
+ # for lb in labs:
650
+ # coords=np.argwhere(markers==lb)
651
+ # if coords.size>0:
652
+ # r_mean=np.mean(coords[:,0])
653
+ # c_mean=np.mean(coords[:,1])
654
+ # ax.plot(c_mean, r_mean, 'x', color='white', markersize=7)
655
+ def _plot_centroids(self, ax):
656
+ markers = self.final_markers_24x24
657
+ labs = np.unique(markers[markers > 0])
658
+ for lb in labs:
659
+ # Use the stored centroid to ensure it lies within the region.
660
+ cent = self.region_centroids.get(lb, None)
661
+ if cent is not None:
662
+ ax.plot(cent[1], cent[0], 'x', color='white', markersize=7)
663
+
664
+ def _add_legend(self, ax):
665
+ markers=self.final_markers_24x24
666
+ labs=np.unique(markers[markers>0])
667
+ # col_array=self.colormap(np.linspace(0,1,len(labs)))
668
+
669
+ handles=[]
670
+ for i,lab in enumerate(labs):
671
+ color=self.label_to_color[lab] / 255.0
672
+ label_str = f"R{lab}"
673
+ h=plt.Line2D(
674
+ [0],[0],
675
+ marker='o',
676
+ color='w',
677
+ markerfacecolor=color, # col_array[i][:3],
678
+ markersize=10,
679
+ label=label_str
680
+ )
681
+ handles.append(h)
682
+
683
+ ax.legend(handles=handles, loc='upper right', fontsize=8, frameon=True)
684
+
685
+
686
+ def _get_subtitles(self):
687
+ markers=self.final_markers_24x24
688
+ labs=np.unique(markers[markers>0])
689
+
690
+ suptitle_str=""
691
+ for i,lab in enumerate(labs):
692
+ prob_val=self.region_prob.get(lab,0.0)
693
+ visit_str=self.region_visited_str.get(lab,"0/0")
694
+ E_i=self.region_expected.get(lab,0.0)
695
+ num_found=self.region_found.get(lab,0)
696
+ steps_next=self.region_steps_next.get(lab,float('inf'))
697
+
698
+ if math.isinf(steps_next):
699
+ steps_str="∞"
700
+ else:
701
+ steps_str=f"{steps_next:.2f}"
702
+
703
+ label_str = (
704
+ f"R{lab}: p={prob_val:.3f}, Vis={visit_str}, "
705
+ f"E_tgts={E_i:.2f}, #found={num_found}, "
706
+ f"E_steps({num_found+1}th)={steps_str}"
707
+ )
708
+ suptitle_str+=label_str+"\n"
709
+
710
+ return suptitle_str
711
+
712
+
713
+ def incorporate_kmeans_boundaries(self, kmeans_label_map, step=0):
714
+ """
715
+ Merges the existing final watershed map (self.final_markers_24x24) with a
716
+ k-means label map of the same shape. Each watershed region is subdivided by
717
+ the boundaries from the k-means map to produce a new, finer segmentation.
718
+
719
+ Then we recompute region stats (probabilities, centroids, etc.) and plot a
720
+ 2×2 figure:
721
+ 1) Original Heatmap
722
+ 2) Old Watershed Segmentation
723
+ 3) k-Means Label Map
724
+ 4) Merged Segmentation
725
+ """
726
+
727
+ # ---------------------------
728
+ # 1) Ensure shapes match
729
+ # ---------------------------
730
+ if kmeans_label_map.shape != self.final_markers_24x24.shape:
731
+ raise ValueError("kmeans_label_map must match the shape of the final watershed map.")
732
+
733
+ # We produce a new merged label map (24×24) subdividing each watershed region
734
+ # by the kmeans labels. For example, if watershed label L intersects multiple
735
+ # kmeans labels, that region is split into multiple subregions.
736
+
737
+ # We'll store it in self.merged_markers_24x24
738
+ self.merged_markers_24x24 = np.zeros_like(self.final_markers_24x24, dtype=np.int32)
739
+
740
+ watershed_labels = np.unique(self.final_markers_24x24[self.final_markers_24x24 > 0])
741
+ next_label = 1
742
+
743
+ for wlab in watershed_labels:
744
+ # Mask for watershed label wlab
745
+ wmask = (self.final_markers_24x24 == wlab)
746
+
747
+ # Among those pixels, we see which kmeans labels appear
748
+ kmeans_vals = np.unique(kmeans_label_map[wmask])
749
+ for klabel in kmeans_vals:
750
+ # Intersection => subregion
751
+ subregion_mask = wmask & (kmeans_label_map == klabel)
752
+ if not np.any(subregion_mask):
753
+ continue
754
+ # Assign a new label
755
+ self.merged_markers_24x24[subregion_mask] = next_label
756
+ next_label += 1
757
+
758
+ # Optional: If you want to merge small subregions again:
759
+ self.merged_markers_24x24 = self._merge_small_regions(self.merged_markers_24x24, self.region_size_thresh, n_iter=1)
760
+
761
+ # Recompute all stats for the merged segmentation:
762
+ self._build_merged_outputs()
763
+
764
+ # Finally, plot the 2×2 figure:
765
+ if self.plot_img:
766
+ self._plot_kmeans_merge(kmeans_label_map, step)
767
+
768
+ return self.merged_markers_24x24
769
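A minimal sketch of the intersection step above on hypothetical 2x4 maps: every overlapping (watershed, k-means) label pair becomes its own new region.

    import numpy as np

    watershed = np.array([[1, 1, 2, 2],
                          [1, 1, 2, 2]])
    kmeans    = np.array([[7, 7, 7, 8],
                          [7, 9, 8, 8]])

    merged, next_label = np.zeros_like(watershed), 1
    for wlab in np.unique(watershed[watershed > 0]):
        wmask = (watershed == wlab)
        for klab in np.unique(kmeans[wmask]):
            merged[wmask & (kmeans == klab)] = next_label
            next_label += 1
    # merged:
    # [[1 1 3 4]
    #  [1 2 4 4]]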
+
770
+
771
+ def _build_merged_outputs(self):
772
+ """
773
+ Re-run binomial logic on self.merged_markers_24x24:
774
+ 1) build color overlay
775
+ 2) scale probabilities
776
+ 3) compute visited fraction, found targets, centroids, etc.
777
+ We'll store them in new variables (prefixed 'merged_') so we don't overwrite
778
+ the original watershed data.
779
+ """
780
+
781
+ markers = self.merged_markers_24x24
782
+ valid_labels = np.unique(markers[markers > 0])
783
+
784
+ # 1) color-labeled overlay
785
+ self.merged_output_color = self._color_label_map(markers)
786
+ self.merged_output_no_vis = self._overlay_heatmap_and_labels(self.raw_heatmap, self.merged_output_color)
787
+
788
+ # 2) visited overlay
789
+ h, w = markers.shape
790
+ visited_mask = np.zeros(markers.shape, dtype=bool)
791
+ for idx in self.flattened_visits:
792
+ rr = idx // w
793
+ cc = idx % w
794
+ if 0 <= rr < h and 0 <= cc < w:
795
+ visited_mask[rr, cc] = True
796
+
797
+ merged_with_vis = self.merged_output_no_vis.copy()
798
+ red_overlay = np.zeros_like(merged_with_vis)
799
+ alpha = 0.3
800
+ for idx in self.flattened_visits:
801
+ rr = idx // w
802
+ cc = idx % w
803
+ if 0 <= rr < h and 0 <= cc < w:
804
+ red_overlay[rr, cc] = [255, 0, 0]
805
+ self.merged_output_with_vis = cv2.addWeighted(merged_with_vis, 1.0, red_overlay, alpha, 0.0)
806
+
807
+ # 3) compute scale factor
808
+ label_sizes = {}
809
+ label_sums = {}
810
+ total_mass = 0.0
811
+ for lb in valid_labels:
812
+ mask = (markers == lb)
813
+ size_i = np.sum(mask)
814
+ label_sizes[lb] = size_i
815
+ s = np.sum(self.raw_heatmap[mask])
816
+ label_sums[lb] = s
817
+ total_mass += s
818
+
819
+ scale = self.N/(total_mass + 1e-9)
820
+
821
+ # 4) visited fraction, found count, etc.
822
+ merged_visited = {}
823
+ for lb in valid_labels:
824
+ merged_visited[lb] = np.sum((markers == lb) & visited_mask)
825
+
826
+ # found_count
827
+ merged_found_count = {lb:0 for lb in valid_labels}
828
+ for idx, num_found in zip(self.flattened_visits, self.num_found_list):
829
+ rr = idx // w
830
+ cc = idx % w
831
+ lb = markers[rr, cc]
832
+ if lb > 0 and num_found > 0:
833
+ merged_found_count[lb] += num_found
834
+
835
+ # store them
836
+ self.merged_region_prob = {}
837
+ self.merged_region_visited_str = {}
838
+ self.merged_region_expected = {}
839
+ self.merged_region_found = {}
840
+ self.merged_region_steps_next = {}
841
+ self.merged_region_centroids = {}
+ self.merged_visited_num = {}  # initialize here; filled per-label below
842
+
843
+ for lb in valid_labels:
844
+ size_i = label_sizes[lb]
845
+ raw_sum = label_sums[lb]
846
+ if size_i > 0:
847
+ p_i = scale*(raw_sum/size_i)
848
+ else:
849
+ p_i = 0.0
850
+ self.merged_region_prob[lb] = p_i
851
+
852
+ visited_num = merged_visited[lb]
853
+ self.merged_visited_num[lb]=visited_num
854
+ self.merged_region_visited_str[lb] = f"{visited_num}/{size_i}"
855
+
856
+ E_i = size_i*p_i
857
+ self.merged_region_expected[lb] = E_i
858
+
859
+ fcount = merged_found_count[lb]
860
+ self.merged_region_found[lb] = fcount
861
+
862
+ # recompute centroid for merged label
863
+ cent = self._find_centroid_in_region(markers, lb)
864
+ self.merged_region_centroids[lb] = cent
865
+
866
+ if E_i > 0:
867
+ next_steps = (fcount+1)*size_i/E_i
868
+ else:
869
+ next_steps = float('inf')
870
+ self.merged_region_steps_next[lb] = next_steps
871
+
872
+
873
+
874
+ def _plot_kmeans_merge(self, kmeans_label_map, step):
875
+ """
876
+ Plots a 2×2 figure:
877
+ (1) Original Heatmap
878
+ (2) Old Final Markers (Watershed)
879
+ (3) k-means Label Map
880
+ (4) Merged Markers
881
+ """
882
+
883
+
884
+ fig, axs = plt.subplots(2, 2, figsize=(12, 12))
885
+ fig.suptitle("Merging K-Means Boundaries", fontsize=16)
886
+ fig.text(0.5, 0.05, self._get_merged_subtitles(), ha="center", va="center", fontsize=12)
887
+ plt.tight_layout()
888
+ plt.subplots_adjust(bottom=0.1) # Prevent overlap
889
+
890
+ # (1) Original heatmap
891
+ axs[0,0].imshow(self.input_heatmap_24x24, cmap='viridis')
892
+ axs[0,0].set_title("1) Original Heatmap")
893
+ axs[0,0].axis("off")
894
+
895
+ # (2) Watershed Segmentation
896
+ old_color = self._color_label_map(self.final_markers_24x24)
897
+ old_overlay = self._overlay_heatmap_and_labels(self.raw_heatmap, old_color)
898
+ axs[0,1].imshow(old_overlay)
899
+ axs[0,1].set_title("2) Watershed Segmentation")
900
+ axs[0,1].axis("off")
901
+ # -------------- NEW: Add the merged legend to the watershed subplot --------------
902
+ self._add_merged_legend(axs[0,1])
903
+ # -------------------------------------------------------------------------------
904
+
905
+ # (3) k-Means Label Map
906
+ kmeans_color = self._color_label_map(kmeans_label_map)
907
+ kmeans_overlay = self._overlay_heatmap_and_labels(self.raw_heatmap, kmeans_color)
908
+ axs[1,0].imshow(kmeans_overlay)
909
+ axs[1,0].set_title("3) K-Means Label Map")
910
+ axs[1,0].axis("off")
911
+
912
+ # (4) Merged Markers
913
+ # We already have self.merged_output_with_vis, but we'll now add a pink overlay for the path
914
+ # merged_with_vis = self.merged_output_with_vis.copy()
915
+
916
+ # Red line for path
917
+ axs[1,1].imshow(self.merged_output_no_vis)
918
+ path_rows, path_cols = [], []
919
+ targets = []
920
+ for i, idx in enumerate(self.flattened_visits):
921
+ rr = idx // self.merged_markers_24x24.shape[1]
922
+ cc = idx % self.merged_markers_24x24.shape[1]
923
+ path_rows.append(rr)
924
+ path_cols.append(cc)
925
+ if self.num_found_list[i] > 0:
926
+ targets.append((rr, cc))
927
+ axs[1,1].plot(path_cols, path_rows, c="r", linewidth=2)
928
+ axs[1,1].plot(path_cols[-1], path_rows[-1], markersize=12, zorder=99, marker="^", ls="-", c="r", mec="black")
929
+ axs[1,1].plot(path_cols[0], path_rows[0], 'co', c="r", markersize=8, zorder=5)
930
+ for target in targets:
931
+ axs[1,1].plot(target[1], target[0], color='g', marker='x', linestyle='-', mec="black", markersize=12, markeredgewidth=4, zorder=999)
932
+
933
+ axs[1,1].set_title("4) Merged Map")
934
+ axs[1,1].axis("off")
935
+
936
+ # Mark centroids for merged subregions
937
+ merged_labs = np.unique(self.merged_markers_24x24[self.merged_markers_24x24 > 0])
938
+ for lab in merged_labs:
939
+ cent = self.merged_region_centroids.get(lab, None)
940
+ if cent is not None:
941
+ axs[1,1].plot(cent[1], cent[0], 'x', color='white', markersize=7)
942
+
943
+ # Add the merged legend to the final subplot as well
944
+ self._add_merged_legend(axs[1,1])
945
+
946
+ # Save figure (same naming logic as your code)
947
+ if not os.path.exists(self.gifs_dir):
948
+ os.makedirs(self.gifs_dir)
949
+
950
+ plt.savefig(f'{self.gifs_dir}/kmeans_merged_{step}.png', dpi=150)
951
+ self.kmeans_merged_frame_files.append(f'{self.gifs_dir}/kmeans_merged_{step}.png')
952
+ plt.close()
953
+
954
+
955
+ def _add_merged_legend(self, ax):
956
+ """
957
+ Similar to _add_legend, but pulling data from self.merged_region_* variables.
958
+ """
959
+ markers = self.merged_markers_24x24
960
+ labs = np.unique(markers[markers>0])
961
+ # col_array = self.colormap(np.linspace(0,1,len(labs)))
962
+
963
+ handles = []
964
+ for i,lab in enumerate(labs):
965
+ color = self.label_to_color[lab] / 255.0
966
+ label_str = f"R{lab}"
967
+ h=plt.Line2D(
968
+ [0],[0],
969
+ marker='o',
970
+ color='w',
971
+ markerfacecolor=color, #col_array[i][:3],
972
+ markersize=10,
973
+ label=label_str
974
+ )
975
+ handles.append(h)
976
+
977
+ ax.legend(handles=handles, loc='upper right', fontsize=8, frameon=True)
978
+
979
+ # TODO: Merge with _get_subtitles
980
+ def _get_merged_subtitles(self):
981
+ markers=self.merged_markers_24x24
982
+ labs=np.unique(markers[markers>0])
983
+
984
+ suptitle_str=""
985
+ for i,lab in enumerate(labs):
986
+ prob_val=self.merged_region_prob.get(lab,0.0)
987
+ visit_str=self.merged_region_visited_str.get(lab,"0/0")
988
+ E_i=self.merged_region_expected.get(lab,0.0)
989
+ num_found=self.merged_region_found.get(lab,0)
990
+ steps_next=self.merged_region_steps_next.get(lab,float('inf'))
991
+
992
+ if math.isinf(steps_next):
993
+ steps_str="∞"
994
+ else:
995
+ steps_str=f"{steps_next:.2f}"
996
+
997
+ label_str = (
998
+ f"R{lab}: p={prob_val:.3f}, Vis={visit_str}, "
999
+ f"E_tgts={E_i:.2f}, #found={num_found}, "
1000
+ f"E_steps({num_found+1}th)={steps_str}"
1001
+ )
1002
+ suptitle_str+=label_str+"\n"
1003
+
1004
+ return suptitle_str
1005
+
1006
+
1007
+ # ------ OTHER HELPER FUNCTIONS -------#
1008
+ def get_label_id(self, patch_idx):
1009
+ """
1010
+ Given a flattened index (row-major order) within the 24×24 grid,
1011
+ return the label id of the corresponding merged region.
1012
+
1013
+ :param patch_idx: Flattened index (0 to H×W-1).
1014
+ :return: The merged region's label id, or 0.0 if the patch
1015
+ is out of bounds (0 means the patch is not in a labeled region).
1016
+ """
1017
+ # Guard: check we have a merged label map
1018
+ if self.merged_markers_24x24 is None:
1019
+ raise ValueError("Merged markers do not exist. "
1020
+ "Please call `incorporate_kmeans_boundaries(...)` first.")
1021
+
1022
+ h, w = self.merged_markers_24x24.shape
1023
+ # Derive row, col
1024
+ row = patch_idx // w
1025
+ col = patch_idx % w
1026
+
1027
+ if row < 0 or row >= h or col < 0 or col >= w:
1028
+ return 0.0
1029
+
1030
+ # Find which merged label region that pixel belongs to
1031
+ label_id = self.merged_markers_24x24[row, col]
1032
+
1033
+ return label_id
1034
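Hypothetical usage, assuming the `segmenter` built in the __main__ block below and that `incorporate_kmeans_boundaries` has already been called:

    row, col = 5, 7                                   # a patch in the 24x24 grid
    w = segmenter.merged_markers_24x24.shape[1]       # grid width (24)
    patch_idx = row * w + col                         # flattened row-major index
    region_label = segmenter.get_label_id(patch_idx)
    region_prob = segmenter.merged_region_prob.get(region_label, 0.0)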
+
1035
+ # ----------------------------
1036
+ # Example MAIN usage
1037
+ # ----------------------------
1038
+ if __name__=="__main__":
1039
+ # 1) Load a 24x24 heatmap
1040
+ file_path="./expt/seg_mask_step0_v1.npy"
1041
+ heatmap_24x24=np.load(file_path)
1042
+ heatmap_24x24=np.clip(heatmap_24x24,0,100)
1043
+
1044
+ # 2) Simulate a random walk
1045
+ max_r, max_c=heatmap_24x24.shape[0]-1, heatmap_24x24.shape[1]-1
1046
+ steps=100
1047
+ current_pos=(0,0)
1048
+ moves_8=[(-1,-1),(-1,0),(-1,1),
1049
+ ( 0,-1), ( 0,1),
1050
+ ( 1,-1),( 1,0), ( 1,1)]
1051
+ path_coords=[current_pos]
1052
+ num_found_list=[0]
1053
+ for _ in range(steps):
1054
+ r,c=current_pos
1055
+ dr,dc=random.choice(moves_8)
1056
+ nr, nc=r+dr,c+dc
1057
+ if 0<=nr<=max_r and 0<=nc<=max_c:
1058
+ current_pos=(nr,nc)
1059
+ path_coords.append(current_pos)
1060
+ # 10% chance for new target
1061
+ # num_found_list.append(random.random()<0.1)
1062
+ num_found_list.append(0)
1063
+ num_found_list[0]=1
1064
+ num_found_list[1]=2
1065
+ num_found_list[2]=3
1066
+
1067
+ # Flatten
1068
+ width=heatmap_24x24.shape[1]
1069
+ flattened_visits=[]
1070
+ for (rr,cc) in path_coords:
1071
+ idx=rr*width+cc
1072
+ flattened_visits.append(idx)
1073
+
1074
+ # 3) Instantiate & run
1075
+ segmenter=WatershedBinomial(
1076
+ N=10.0, # total expected # across map
1077
+ blur_kernel=(5,5),
1078
+ sure_fg_factor=0.5,
1079
+ dilation_kernel_size=(2,2),
1080
+ dilation_iters=3,
1081
+ region_size_thresh=9,
1082
+ plot_img=True,
1083
+ gifs_dir="/home/user/VLM-Search/inference/test_results/gifs",
1084
+ colormap=plt.cm.tab20,
1085
+ title="Watershed + Binomial Steps"
1086
+ )
1087
+ final=segmenter.run(heatmap_24x24, flattened_visits, num_found_list, step=0)
1088
+
1089
+ # Suppose you load a kmeans_label_map in the same shape (24×24)
1090
+ file_path="./expt/smoothed_labels_2d.npy"
1091
+ kmeans_label_map = np.load(file_path)
1092
+
1093
+ # Now incorporate it into the final segmentation:
1094
+ segmenter.incorporate_kmeans_boundaries(kmeans_label_map, step=0)
Taxabind/Taxabind/SoundBind/config.py ADDED
@@ -0,0 +1,30 @@
1
+ from easydict import EasyDict as edict
2
+
3
+ config = edict()
4
+ config.train_df = '/mnt/hdd/inat2021_ds/sound_train/sound_image_pairs_filtered.csv' # 'train_df.json'
5
+ config.val_df= '/mnt/hdd/inat2021_ds/sound_val/sound_image_pairs_filtered.csv' # 'val_df.csv'
6
+ config.data_path = '/mnt/hdd/inat2021_ds/sound_train'
7
+ # HOW ABOUT VAL?
8
+
9
+ # # Toy Example
10
+ # config.train_df = '../scripts/preprocess/expt/sound_train/sound_image_pairs_filtered.csv' # 'train_df.json'
11
+ # config.val_df= '../scripts/preprocess/expt/sound_train/sound_image_pairs_filtered.csv' # 'val_df.csv'
12
+ # config.data_path = '../scripts/preprocess/expt/sound_train'
13
+
14
+ config.batch_size = 256
15
+ config.lr = 1e-4
16
+ config.accumulate_grad_batches = 8
17
+ config.max_epochs = 20
18
+ config.num_workers = 16
19
+ config.devices = 2
20
+ config.val_check_interval = 0.5
21
+
22
+
23
+ config.save_dir = 'checkpoints'
24
+ config.filename = 'soundbind-{epoch:02d}-{val_loss:.2f}'
25
+
26
+ config.locked_tuning = True
27
+
28
+ # TEMP
29
+ config.sat_encoder = 'openai/clip-vit-large-patch14-336' # openai/clip-vit-base-patch16, openai/clip-vit-large-patch14-336
30
+ config.patch_size = 14
Taxabind/Taxabind/SoundBind/dataloader.py ADDED
@@ -0,0 +1,57 @@
1
+ from torch.utils.data import Dataset
2
+ from torchvision import transforms
3
+ import os
4
+ from PIL import Image
5
+ import pandas as pd
6
+ from config import config
7
+ from sound_encoder import get_audio_clap
8
+ import torch
9
+ from tqdm import tqdm
10
+
11
+ #img, sound, year, month, day
12
+ class INatDataset(Dataset):
13
+ def __init__(self,
14
+ data_file,
15
+ mode='train'):
16
+ self.data_file = pd.read_csv(data_file)
17
+ if mode=='train':
18
+ self.transform = transforms.Compose([
19
+ transforms.Resize((256, 256)),
20
+ transforms.RandomCrop((224, 224)),
21
+ transforms.RandomHorizontalFlip(0.5),
22
+ transforms.GaussianBlur(5, (0.01, 1.0)),
23
+ transforms.ToTensor(),
24
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
25
+ std=[0.229, 0.224, 0.225])
26
+ ])
27
+ else:
28
+ self.transform = transforms.Compose([
29
+ transforms.Resize((256, 256)),
30
+ transforms.CenterCrop((224, 224)),
31
+ transforms.ToTensor(),
32
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
33
+ std=[0.229, 0.224, 0.225])
34
+ ])
35
+ self.species_text = self.data_file['scientific_name'].tolist()
36
+ self.species_classes = list(set(self.species_text))
37
+
38
+ def __len__(self):
39
+ return len(self.data_file)
40
+
41
+ def get_sample(self,idx):
42
+ sample = self.data_file.iloc[idx]
43
+ id = sample.id
44
+ sound_format = sample.sound_format
45
+ image_path = os.path.join(config['data_path'],"images",str(id)+".jpg")
46
+ sound_path = os.path.join(config['data_path'],"sounds_mp3",str(id)+"."+'mp3')
47
+ sound = get_audio_clap(sound_path) # , sound_format)
48
+
49
+ for k in sound.keys():
50
+ sound[k] = sound[k].squeeze(0)
51
+ image = self.transform(Image.open(image_path))
52
+
53
+ return image, sound
54
+
55
+ def __getitem__(self, idx):
56
+ image, sound = self.get_sample(idx)
57
+ return image, sound, self.species_classes.index(self.data_file.iloc[idx]['scientific_name'])
Taxabind/Taxabind/SoundBind/model.py ADDED
@@ -0,0 +1,141 @@
1
+ import os
2
+ import sys
3
+ # Ensure the correct directory is at the front of sys.path
4
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
5
+
6
+ import open_clip
7
+ import pytorch_lightning as pl
8
+ import torch
9
+ import torch.nn as nn
10
+ from sound_encoder import CLAP_audiomodel_withProjection as AudioEncoder
11
+ import numpy as np
12
+ from torch.utils.data import DataLoader
13
+ from config import config
14
+ import random
15
+ from dataloader import INatDataset
16
+ from pytorch_lightning.callbacks import ModelCheckpoint
17
+
18
+ def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
19
+ audio_loss = contrastive_loss(similarity)
20
+ ground_img_loss = contrastive_loss(similarity.t())
21
+ return 0.5*audio_loss + 0.5*ground_img_loss
22
+
23
+ def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
24
+
25
+ return nn.functional.cross_entropy(logits[:logits.shape[1]], torch.arange(logits.shape[1], device=logits.device))
26
+
27
+ class AudioBind(pl.LightningModule):
28
+ def __init__(self, train_dataset, val_dataset, **kwargs):
29
+ super().__init__()
30
+ self.train_dataset = train_dataset
31
+ self.val_dataset = val_dataset
32
+ self.model, *_ = open_clip.create_model_and_transforms('hf-hub:imageomics/bioclip')
33
+ if config.locked_tuning:
34
+ for param in self.model.parameters():
35
+ param.requires_grad = False
36
+ self.audio_encoder = AudioEncoder(freeze=False)
37
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
38
+ self.batch_size = kwargs.get('batch_size')
39
+ self.num_workers = kwargs.get('num_workers')
40
+ self.lr = kwargs.get('lr', 1e-4)
41
+
42
+ def forward(self, image, audio):
43
+ with torch.no_grad():
44
+ image_embeds, *_ = self.model(image)
45
+ unnormalized_audio_embeds = self.audio_encoder(audio)
46
+ audio_embeds = torch.nn.functional.normalize(unnormalized_audio_embeds, dim=-1)
47
+ return image_embeds, audio_embeds
48
+
49
+ def shared_step(self, batch):
50
+ image, audio, *_ = batch
51
+ image_embeds, audio_embeds = self(image, audio)
52
+ logit_scale = self.logit_scale.exp()
53
+ logits_per_img = torch.matmul(image_embeds,audio_embeds.t())*logit_scale
54
+ cross_contrastive_loss = clip_loss(logits_per_img)
55
+ return cross_contrastive_loss
56
+
57
+ def training_step(self, batch, batch_idx):
58
+ loss = self.shared_step(batch)
59
+ self.log('train_loss', loss, sync_dist=True, prog_bar=True, on_epoch=True, batch_size=self.batch_size)
60
+ self.log('temperature', self.logit_scale.data, prog_bar=True, on_epoch=True, batch_size=self.batch_size)
61
+ return loss
62
+
63
+ def on_train_batch_end(self,outputs,batch, batch_idx):
64
+ if self.logit_scale.data > np.log(100):
65
+ self.logit_scale.data = torch.clamp(self.logit_scale.data, 0, np.log(100))
66
+
67
+ def validation_step(self, batch, batch_idx):
68
+ loss = self.shared_step(batch)
69
+ self.log('val_loss', loss, sync_dist=True, prog_bar=True, on_epoch=True, batch_size=self.batch_size)
70
+ return loss
71
+
72
+ def train_dataloader(self):
73
+ return DataLoader(self.train_dataset,
74
+ batch_size=self.batch_size,
75
+ num_workers=self.num_workers,
76
+ shuffle=True,
77
+ persistent_workers=False)
78
+
79
+ def val_dataloader(self):
80
+ return DataLoader(self.val_dataset,
81
+ batch_size=self.batch_size,
82
+ num_workers=self.num_workers,
83
+ shuffle=False,
84
+ persistent_workers=False)
85
+
86
+ def configure_optimizers(self):
87
+ params = self.parameters()
88
+ self.optim = torch.optim.AdamW(params,
89
+ lr=self.lr,
90
+ betas=(0.9,0.98),
91
+ eps=1e-6
92
+ )
93
+ self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
94
+ optimizer=self.optim,
95
+ T_0=20
96
+ )
97
+ return [self.optim], [self.scheduler]
98
+
99
+ def seed_everything(seed=42):
100
+ """
101
+ seed: int
102
+ """
103
+ torch.manual_seed(seed)
104
+ torch.cuda.manual_seed_all(seed)
105
+ np.random.seed(seed)
106
+ random.seed(seed)
107
+ torch.backends.cudnn.deterministic = True
108
+ torch.backends.cudnn.benchmark = False
109
+ os.environ["PYTHONHASHSEED"] = str(seed)
110
+
111
+ if __name__=='__main__':
112
+ import warnings
113
+ warnings.filterwarnings("ignore")
114
+ torch.set_warn_always(False)
115
+
116
+ seed_everything()
117
+ train_dataset = INatDataset(data_file=config.train_df, mode='train')
118
+ val_dataset = INatDataset(data_file=config.val_df, mode='val')
119
+ kwargs = {'batch_size':config.batch_size, 'num_workers': config.num_workers}
120
+
121
+ model = AudioBind(train_dataset, val_dataset, **kwargs)
122
+ torch.cuda.empty_cache()
123
+
124
+ checkpoint = ModelCheckpoint(
125
+ monitor='val_loss',
126
+ dirpath=config.save_dir,
127
+ filename=config.filename,
128
+ mode='min',
129
+ save_top_k=3
130
+ )
131
+ trainer = pl.Trainer(
132
+ accelerator='gpu',
133
+ devices=config.devices,
134
+ strategy='ddp',
135
+ max_epochs=config.max_epochs,
136
+ num_nodes=1,
137
+ callbacks=[checkpoint],
138
+ accumulate_grad_batches=config.accumulate_grad_batches,
139
+ log_every_n_steps=1
140
+ )
141
+ trainer.fit(model)
Taxabind/Taxabind/SoundBind/sound_encoder.py ADDED
@@ -0,0 +1,45 @@
1
+ #Hugging face way of loading AudioCLAP model
2
+ from transformers import ClapProcessor
3
+ from transformers import ClapAudioModelWithProjection
4
+ import pytorch_lightning as pl
5
+ import torch.nn as nn
6
+ import torch
7
+ import numpy as np
8
+ import torchaudio
9
+ import os
10
+
11
+ processor = ClapProcessor.from_pretrained("laion/clap-htsat-fused")
12
+ SAMPLE_RATE = 48000
13
+
14
+ def get_audio_clap(path_to_audio,format="mp3",padding="repeatpad",truncation="fusion"):
15
+ track, sr = torchaudio.load(path_to_audio, format=format) # torchaudio.load(path_to_audio)
16
+ track = track.mean(axis=0)
17
+ track = torchaudio.functional.resample(track, orig_freq=sr, new_freq=SAMPLE_RATE)
18
+ output = processor(audios=track, sampling_rate=SAMPLE_RATE, max_length_s=10, return_tensors="pt",padding=padding,truncation=truncation)
19
+ return output
20
+
21
+
22
+ class CLAP_audiomodel_withProjection(pl.LightningModule):
23
+ def __init__(self,freeze=False):
24
+ super().__init__()
25
+ if freeze:
26
+ self.model = ClapAudioModelWithProjection.from_pretrained("laion/clap-htsat-fused").eval()
27
+ for params in self.model.parameters():
28
+ params.requires_grad=False
29
+ else:
30
+ self.model = ClapAudioModelWithProjection.from_pretrained("laion/clap-htsat-fused").train()
31
+ def forward(self,audio):
32
+ batch_embeddings_audio = self.model(**audio)['audio_embeds']
33
+ return batch_embeddings_audio
34
+
35
+ if __name__ == '__main__':
36
+ path_to_audio ="/mnt/hdd/inat2021_ds/sound_train/sounds_mp3/165878447.mp3"
37
+ sample = get_audio_clap(path_to_audio)
38
+ print(sample.keys())
39
+
40
+ sample['input_features'] = torch.concat([sample['input_features'],sample['input_features']],axis=0)
41
+ sample['is_longer'] = torch.concat([sample['is_longer'],sample['is_longer']],axis=0)
42
+ print(sample['input_features'].shape,sample['is_longer'].shape) #torch.Size([2, 4, 1001, 64]), torch.Size([2, 1])
43
+ model = CLAP_audiomodel_withProjection(freeze=False)
44
+ audio_feat = model(sample)
45
+ print(audio_feat.shape) #torch.Size([2, 512])
app.py CHANGED
@@ -1,222 +1,123 @@
1
  """
2
- Search-TTA demo
 
 
 
 
 
 
3
  """
4
 
5
  # ────────────────────────── imports ───────────────────────────────────
6
- import cv2
 
7
  import gradio as gr
8
  import torch
9
- import numpy as np
10
  from PIL import Image
11
- import matplotlib.pyplot as plt
12
- import io
13
- import torchaudio
14
- import spaces # integration with ZeroGPU on hf
15
-
16
- from torchvision import transforms
17
- import open_clip
18
- from clip_vision_per_patch_model import CLIPVisionPerPatchModel
19
- from transformers import ClapAudioModelWithProjection
20
- from transformers import ClapProcessor
21
-
22
- # ────────────────────────── global config & models ────────────────────
23
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
-
25
- # BioCLIP (ground-image & text encoder)
26
- bio_model, _, _ = open_clip.create_model_and_transforms("hf-hub:imageomics/bioclip")
27
- bio_model = bio_model.to(device).eval()
28
- bio_tokenizer = open_clip.get_tokenizer("hf-hub:imageomics/bioclip")
29
-
30
- # Satellite patch encoder CLIP-L-336 per-patch)
31
- sat_model: CLIPVisionPerPatchModel = (
32
- CLIPVisionPerPatchModel.from_pretrained("derektan95/search-tta-sat")
33
- .to(device)
34
- .eval()
35
- )
36
-
37
- # Sound CLAP model
38
- sound_model: ClapAudioModelWithProjection = (
39
- ClapAudioModelWithProjection.from_pretrained("derektan95/search-tta-sound")
40
- .to(device)
41
- .eval()
42
- )
43
- sound_processor: ClapProcessor = ClapProcessor.from_pretrained("derektan95/search-tta-sound")
44
- SAMPLE_RATE = 48000
45
-
46
- logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
47
- logit_scale = logit_scale.exp()
48
- blur_kernel = (5,5)
49
-
50
- # ────────────────────────── transforms (exact spec) ───────────────────
51
- img_transform = transforms.Compose(
52
- [
53
- transforms.Resize((256, 256)),
54
- transforms.CenterCrop((224, 224)),
55
- transforms.ToTensor(),
56
- transforms.Normalize(
57
- mean=[0.485, 0.456, 0.406],
58
- std=[0.229, 0.224, 0.225],
59
- ),
60
- ]
61
- )
62
 
63
- imo_transform = transforms.Compose(
64
- [
65
- transforms.Resize((336, 336)),
66
- transforms.ToTensor(),
67
- transforms.Normalize(
68
- mean=[0.485, 0.456, 0.406],
69
- std=[0.229, 0.224, 0.225],
70
- ),
71
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  )
73
 
74
- def get_audio_clap(path_to_audio,format="mp3",padding="repeatpad",truncation="fusion"):
75
- track, sr = torchaudio.load(path_to_audio, format=format) # torchaudio.load(path_to_audio)
76
- track = track.mean(axis=0)
77
- track = torchaudio.functional.resample(track, orig_freq=sr, new_freq=SAMPLE_RATE)
78
- output = sound_processor(audios=track, sampling_rate=SAMPLE_RATE, max_length_s=10, return_tensors="pt",padding=padding,truncation=truncation)
79
- return output
80
-
81
- # ────────────────────────── helpers ───────────────────────────────────
82
-
83
- @torch.no_grad()
84
- def _encode_ground(img_pil: Image.Image) -> torch.Tensor:
85
- img = img_transform(img_pil).unsqueeze(0).to(device)
86
- img_embeds, *_ = bio_model(img)
87
- return img_embeds
88
-
89
-
90
- @torch.no_grad()
91
- def _encode_text(text: str) -> torch.Tensor:
92
- toks = bio_tokenizer(text).to(device)
93
- _, txt_embeds, _ = bio_model(text=toks)
94
- return txt_embeds
95
-
96
-
97
- @torch.no_grad()
98
- def _encode_sat(img_pil: Image.Image) -> torch.Tensor:
99
- imo = imo_transform(img_pil).unsqueeze(0).to(device)
100
- imo_embeds = sat_model(imo)
101
- return imo_embeds
102
-
103
-
104
- @torch.no_grad()
105
- def _encode_sound(sound) -> torch.Tensor:
106
- processed_sound = get_audio_clap(sound)
107
- for k in processed_sound.keys():
108
- processed_sound[k] = processed_sound[k].to(device)
109
- unnormalized_audio_embeds = sound_model(**processed_sound).audio_embeds
110
- sound_embeds = torch.nn.functional.normalize(unnormalized_audio_embeds, dim=-1)
111
- return sound_embeds
112
-
113
-
114
- def _similarity_heatmap(query: torch.Tensor, patches: torch.Tensor) -> np.ndarray:
115
- sims = torch.matmul(query, patches.t()) * logit_scale
116
- sims = sims.t().sigmoid()
117
- sims = sims[1:].squeeze() # drop CLS token
118
- side = int(np.sqrt(len(sims)))
119
- sims = sims.reshape(side, side)
120
- return sims.cpu().detach().numpy()
121
-
122
-
123
- def _array_to_pil(arr: np.ndarray) -> Image.Image:
124
- """
125
- Render arr with viridis, automatically stretching its own min→max to 0→1
126
- so that the most-similar patches appear yellow.
127
- """
128
-
129
- # Gausian Smoothing
130
- if blur_kernel != (0,0):
131
- arr = cv2.GaussianBlur(arr, blur_kernel, 0)
132
 
133
- # --- contrast-stretch to local 0-1 range --------------------------
134
- arr_min, arr_max = float(arr.min()), float(arr.max())
135
- if arr_max - arr_min < 1e-6: # avoid /0 when the heat-map is flat
136
- arr_scaled = np.zeros_like(arr)
137
- else:
138
- arr_scaled = (arr - arr_min) / (arr_max - arr_min)
139
- # ------------------------------------------------------------------
140
- fig, ax = plt.subplots(figsize=(2.6, 2.6), dpi=96)
141
- ax.imshow(arr_scaled, cmap="viridis", vmin=0.0, vmax=1.0)
142
- ax.axis("off")
143
- buf = io.BytesIO()
144
- plt.tight_layout(pad=0)
145
- fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
146
- plt.close(fig)
147
- buf.seek(0)
148
- return Image.open(buf)
149
-
150
- # ────────────────────────── main inference ────────────────────────────
151
- # integration with ZeroGPU on hf
152
- @spaces.GPU
153
  def process(
154
- sat_img: Image.Image,
155
- taxonomy: str,
156
  ground_img: Image.Image | None,
157
- sound: torch.Tensor | None,
158
  ):
159
- if sat_img is None:
160
- return None, None
161
-
162
- patches = _encode_sat(sat_img)
163
-
164
- heat_ground, heat_text, heat_sound = None, None, None
165
 
166
- if ground_img is not None:
167
- q_img = _encode_ground(ground_img)
168
- heat_ground = _array_to_pil(_similarity_heatmap(q_img, patches))
 
 
 
 
 
169
 
170
- if taxonomy.strip():
171
- q_txt = _encode_text(taxonomy.strip())
172
- heat_text = _array_to_pil(_similarity_heatmap(q_txt, patches))
173
 
174
- if sound is not None:
175
- q_sound = _encode_sound(sound)
176
- heat_sound = _array_to_pil(_similarity_heatmap(q_sound, patches))
177
 
178
- return heat_ground, heat_text, heat_sound
 
179
 
180
 
181
  # ────────────────────────── Gradio UI ─────────────────────────────────
182
- with gr.Blocks(title="Search-TTA", theme=gr.themes.Base()) as demo:
183
 
184
- with gr.Row():
185
- gr.Markdown(
186
  """
187
- <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
188
- <div>
189
- <h1>Search-TTA: A Multimodal Test-Time Adaptation Framework for Visual Search in the Wild</h1>
190
- <span></span>
191
- <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
192
- <a href="https://search-tta.github.io">Project Website</a>
193
- </h2>
194
- <span></span>
195
- <h2 style='font-weight: 450; font-size: 0.5rem; margin: 0rem'>[Work in Progress]</h2>
196
- </div>
197
- </div>
198
  """
199
- # <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>WACV 2025</h2>
200
-
201
- # <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
202
- # <a href="https://derektan95.github.io">Derek M. S. Tan</a>,
203
- # <a href="https://chinchinati.github.io/">Shailesh</a>,
204
- # <a href="https://www.linkedin.com/in/boyang-liu-nus">Boyang Liu</a>,
205
- # <a href="https://www.linkedin.com/in/loki-silvres">Alok Raj</a>,
206
- # <a href="https://www.linkedin.com/in/ang-qi-xuan-714347142">Qi Xuan Ang</a>,
207
- # <a href="https://weihengdai.top">Weiheng Dai</a>,
208
- # <a href="https://www.linkedin.com/in/tanishqduhan">Tanishq Duhan</a>,
209
- # <a href="https://www.linkedin.com/in/jimmychiun">Jimmy Chiun</a>,
210
- # <a href="https://www.yuhongcao.online/">Yuhong Cao</a>,
211
- # <a href="https://www.cs.toronto.edu/~florian/">Florian Shkurti</a>,
212
- # <a href="https://www.marmotlab.org/bio.html">Guillaume Sartoretti</a>
213
- # </h2>
214
- # <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>National University of Singapore, University of Toronto, IIT-Dhanbad, Singapore Technologies Engineering</h2>
215
- )
216
 
217
  with gr.Row(variant="panel"):
218
-
219
- # LEFT COLUMN (satellite, taxonomy, run)
220
  with gr.Column():
221
  sat_input = gr.Image(
222
  label="Satellite Image",
@@ -228,54 +129,21 @@ with gr.Blocks(title="Search-TTA", theme=gr.themes.Base()) as demo:
228
  label="Full Taxonomy Name (optional)",
229
  placeholder="e.g. Animalia Chordata Mammalia Rodentia Sciuridae Marmota marmota",
230
  )
231
-
232
- # ─── NEW: sound input ───────────────────────────
233
- sound_input = gr.Audio(
234
- label="Sound Input (optional)",
235
- sources=["upload"], # or "microphone" / "url" as you prefer
236
- type="filepath", # or "numpy" if you want raw arrays
237
- )
238
- run_btn = gr.Button("Run", variant="primary")
239
-
240
- # RIGHT COLUMN (ground image + two heat-maps)
241
- with gr.Column():
242
  ground_input = gr.Image(
243
  label="Ground-level Image (optional)",
244
  sources=["upload"],
245
  type="pil",
246
  height=320,
247
  )
248
- gr.Markdown("### Heat-map Results")
249
- with gr.Row():
250
- # Separate label and image to avoid overlap
251
- with gr.Column(scale=1, min_width=100):
252
- gr.Markdown("**Ground Image Query**", elem_id="label-ground")
253
- heat_ground_out = gr.Image(
254
- show_label=False,
255
- height=160,
256
- # width=160,
257
- )
258
- with gr.Column(scale=1, min_width=100):
259
- gr.Markdown("**Text Query**", elem_id="label-text")
260
- heat_text_out = gr.Image(
261
- show_label=False,
262
- height=160,
263
- # width=160,
264
- )
265
- with gr.Column(scale=1, min_width=100):
266
- gr.Markdown("**Sound Query**", elem_id="label-sound")
267
- heat_sound_out = gr.Image(
268
- show_label=False,
269
- height=160,
270
- # width=160,
271
- )
272
- # ─── NEW: sound output ─────────────────────────
273
- # sound_output = gr.Audio(
274
- # label="Playback",
275
- # )
276
-
277
 
278
- # EXAMPLES
279
  with gr.Row():
280
  gr.Markdown("### In-Domain Taxonomy")
281
  with gr.Row():
@@ -285,40 +153,34 @@ with gr.Blocks(title="Search-TTA", theme=gr.themes.Base()) as demo:
285
  "examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/80645_39.76079_-74.10316.jpg",
286
  "examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/cc1ebaf9-899d-49f2-81c8-d452249a8087.jpg",
287
  "Animalia Chordata Aves Charadriiformes Laridae Larus marinus",
288
- "examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/89758229.mp3"
289
  ],
290
  [
291
  "examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/28871_-12.80255_-69.29999.jpg",
292
  "examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/1b8064f8-7deb-4b30-98cd-69da98ba6a3d.jpg",
293
  "Animalia Chordata Mammalia Rodentia Caviidae Hydrochoerus hydrochaeris",
294
- "examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/166631961.mp3"
295
  ],
296
  [
297
  "examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/277303_38.72364_-75.07749.jpg",
298
  "examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/0b9cc264-a2ba-44bd-8e41-0d01a6edd1e8.jpg",
299
  "Animalia Arthropoda Malacostraca Decapoda Ocypodidae Ocypode quadrata",
300
- "examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/12372063.mp3"
301
  ],
302
  [
303
  "examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/388246_45.49036_7.14796.jpg",
304
  "examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/327e1f07-692b-4140-8a3e-bd098bc064ff.jpg",
305
  "Animalia Chordata Mammalia Rodentia Sciuridae Marmota marmota",
306
- "examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/59677071.mp3"
307
  ],
308
  [
309
  "examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/410613_5.35573_100.28948.jpg",
310
  "examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/461d8e6c-0e66-4acc-8ecd-bfd9c218bc14.jpg",
311
  "Animalia Chordata Reptilia Squamata Varanidae Varanus salvator",
312
- None
313
  ],
314
  ],
315
- inputs=[sat_input, ground_input, taxonomy_input, sound_input],
316
- outputs=[heat_ground_out, heat_text_out, heat_sound_out],
317
- fn=process,
318
  cache_examples=False,
319
  )
320
 
321
- # EXAMPLES
322
  with gr.Row():
323
  gr.Markdown("### Out-Domain Taxonomy")
324
  with gr.Row():
@@ -328,48 +190,45 @@ with gr.Blocks(title="Search-TTA", theme=gr.themes.Base()) as demo:
328
  "examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/27423_35.64005_-121.17595.jpg",
329
  "examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/3aac526d-c921-452a-af6a-cb4f2f52e2c4.jpg",
330
  "Animalia Chordata Mammalia Carnivora Phocidae Mirounga angustirostris",
331
- "examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/3123948.mp3"
332
  ],
333
  [
334
  "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/1528408_13.00422_80.23033.jpg",
335
  "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/37faabd2-a613-4461-b27e-82fe5955ecaf.jpg",
336
  "Animalia Chordata Mammalia Carnivora Canidae Canis aureus",
337
- "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/189318716.mp3"
338
  ],
339
  [
340
  "examples/Animalia_Chordata_Mammalia_Carnivora_Ursidae_Ursus_americanus/yosemite_v3_resized.png",
341
  "examples/Animalia_Chordata_Mammalia_Carnivora_Ursidae_Ursus_americanus/248820933.jpeg",
342
  "Animalia Chordata Mammalia Carnivora Ursidae Ursus americanus",
343
- None
344
  ],
345
  [
346
  "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/304160_34.0144_-119.54417.jpg",
347
  "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/0cbdfbf2-6cfe-4d61-9602-c949f24d0293.jpg",
348
  "Animalia Chordata Mammalia Carnivora Canidae Urocyon littoralis",
349
- None
350
  ],
351
  ],
352
- inputs=[sat_input, ground_input, taxonomy_input, sound_input],
353
- outputs=[heat_ground_out, heat_text_out, heat_sound_out],
354
- fn=process,
355
  cache_examples=False,
356
  )
357
 
358
- # CALLBACK
359
  run_btn.click(
360
  fn=process,
361
- inputs=[sat_input, taxonomy_input, ground_input, sound_input],
362
- outputs=[heat_ground_out, heat_text_out, heat_sound_out],
363
  )
364
 
365
- # Footer to point out to model and data from app page.
366
- gr.Markdown(
367
- """
368
- The satellite image CLIP encoder is fine-tuned using [Sentinel-2 Level 2A](https://docs.sentinel-hub.com/api/latest/data/sentinel-2-l2a/) satellite image and taxonomy images (with GPS locations) from [iNaturalist](https://inaturalist.org/). The sound CLIP encoder is fine-tuned with a subset of the same taxonomy images and their corresponding sounds from [iNaturalist](https://inaturalist.org/). Note that while some of the examples above result in poor probability distributions, they will be improved using our test-time adaptation framework during the search process.
369
- """
370
- )
371
 
372
- # LAUNCH
373
  if __name__ == "__main__":
 
 
374
  demo.queue(max_size=15)
375
  demo.launch(share=True)
 
1
  """
2
+ Simplified Gradio demo for Search-TTA evaluation.
3
+ This version mirrors the layout of `app_BACKUP.py` but:
4
+ 1. Does not load the OpenCLIP / CLAP / satellite encoders at import time.
5
+ 2. Keeps only the Satellite and Ground-level image inputs.
6
+ 3. Exposes the high-level wrapper classes `ClipSegTTA` and
7
+ `TestWorker` and calls `TestWorker.run_episode` inside the
8
+ `process` callback.
9
  """
10
 
11
  # ────────────────────────── imports ───────────────────────────────────
12
+ from pathlib import Path
13
+
14
  import gradio as gr
15
  import torch
 
16
  from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
+ # Import configuration & RL / TTA utilities -------------------------------------------------
19
+ # NOTE: we import * so that the global names (e.g. USE_GPU, MODEL_NAME, etc.)
20
+ # are available exactly as referenced later in the unchanged snippet.
21
+ from test_parameter import * # noqa: F403, F401 (wild-import is intentional here)
22
+
23
+ from model import PolicyNet # noqa: E402 – after wild import on purpose
24
+ from test_multi_robot_worker import TestWorker # noqa: E402
25
+ # from Taxabind.TaxaBind.SatBind.clip_seg_tta import ClipSegTTA # noqa: E402
26
+ from Taxabind.Taxabind.SatBind.clip_seg_tta import ClipSegTTA
27
+
28
+
29
+ # CHANGE ME!
30
+ currEpisode = 0
31
+
32
+ # Prepare the model
33
+ # device = torch.device('cpu') #if USE_GPU_TRAINING else torch.device('cpu')
34
+ device = torch.device('cuda') if USE_GPU else torch.device('cpu')
35
+ policy_net = PolicyNet(INPUT_DIM, EMBEDDING_DIM).to(device)
36
+ # script_dir = os.path.dirname(os.path.abspath(__file__))
37
+ script_dir = Path(__file__).resolve().parent
38
+ print("real_script_dir: ", script_dir)
39
+ # checkpoint = torch.load(f'{script_dir}/modules/vlm_search/{model_path}/{MODEL_NAME}')
40
+ checkpoint = torch.load(f'{model_path}/{MODEL_NAME}')
41
+ policy_net.load_state_dict(checkpoint['policy_model'])
42
+ print('Model loaded!')
43
+ # print(next(policy_net.parameters()).device)
44
+
45
+ # Init Taxabind here (only need to init once)
46
+ if TAXABIND_TTA:
47
+ # self.clip_seg_tta = None
48
+ clip_seg_tta = ClipSegTTA(
49
+ img_dir=TAXABIND_IMG_DIR,
50
+ imo_dir=TAXABIND_IMO_DIR,
51
+ json_path=TAXABIND_INAT_JSON_PATH,
52
+ sat_to_img_ids_json_path=TAXABIND_SAT_TO_IMG_IDS_JSON_PATH,
53
+ patch_size=TAXABIND_PATCH_SIZE,
54
+ sat_checkpoint_path=TAXABIND_SAT_CHECKPOINT_PATH,
55
+ sample_index = 0, # Set using 'reset' in worker
56
+ blur_kernel = TAXABIND_GAUSSIAN_BLUR_KERNEL,
57
+ device=device,
58
+ sat_to_img_ids_json_is_train_dict=False, # for search ds val
59
+ tax_to_filter_val=QUERY_TAX,
60
+ load_model=USE_CLIP_PREDS,
61
+ initial_modality=INITIAL_MODALITY,
62
+ sound_data_path = TAXABIND_SOUND_DATA_PATH,
63
+ sound_checkpoint_path=TAXABIND_SOUND_CHECKPOINT_PATH,
64
+ # sat_filtered_json_path=TAXABIND_FILTERED_INAT_JSON_PATH,
65
+ )
66
+ print("ClipSegTTA Loaded!")
67
+ else:
68
+ clip_seg_tta = None
69
+
70
+ # Define TestWorker
71
+ planner = TestWorker(
72
+ meta_agent_id=0,
73
+ n_agent=1,
74
+ policy_net=policy_net,
75
+ global_step=3,
76
+ device='cuda',
77
+ greedy=True,
78
+ save_image=SAVE_GIFS,
79
+ clip_seg_tta=clip_seg_tta
80
  )
81
 
82
+ # ────────────────────────── Gradio process fn ─────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  def process(
85
+ sat_img: Image.Image | None,
 
86
  ground_img: Image.Image | None,
87
+ taxonomy: str | None = None,
88
  ):
89
+ """Callback executed when the user presses **Run** in the UI.
 
 
 
 
 
90
 
91
+ At test-time we simply trigger the RL search episode via
92
+ ``planner.run_episode`` and return its performance metrics.
93
+ The image inputs are currently *not* used directly here but are
94
+ retained to conform to the requested interface.
95
+ """
96
+ # If no satellite image is provided we bail out early.
97
+ if sat_img is None:
98
+ return {"error": "Please provide a satellite image."}
99
 
100
+ # Optionally you may want to reset episode index or make it configurable.
101
+ # For now we hard-code episode 0, mirroring the snippet.
102
+ planner.run_episode(currEpisode)
103
 
104
+ print("planner.perf_metrics: ", planner.perf_metrics)
 
 
105
 
106
+ # Return the collected performance metrics so they can be inspected.
107
+ return planner.perf_metrics
108
 
109
 
110
  # ────────────────────────── Gradio UI ─────────────────────────────────
111
+ with gr.Blocks(title="Search-TTA (Simplified)", theme=gr.themes.Base()) as demo:
112
 
113
+ gr.Markdown(
 
114
  """
115
+ # Search-TTA Simplified Demo
116
+ **Satellite ↔ Ground-level Visual Search** via RL Test-Time Adaptation.
 
 
 
 
 
 
 
 
 
117
  """
118
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
  with gr.Row(variant="panel"):
 
 
121
  with gr.Column():
122
  sat_input = gr.Image(
123
  label="Satellite Image",
 
129
  label="Full Taxonomy Name (optional)",
130
  placeholder="e.g. Animalia Chordata Mammalia Rodentia Sciuridae Marmota marmota",
131
  )
 
 
 
 
 
 
 
 
 
 
 
132
  ground_input = gr.Image(
133
  label="Ground-level Image (optional)",
134
  sources=["upload"],
135
  type="pil",
136
  height=320,
137
  )
138
+ run_btn = gr.Button("Run", variant="primary")
139
+
140
+ with gr.Column():
141
+ gr.Markdown("### Episode Metrics")
142
+ metrics_out = gr.JSON(label="Performance Metrics")
143
+
144
+ # Bind callback
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
 
146
+ # EXAMPLES – copied from original demo (satellite, ground, taxonomy only)
147
  with gr.Row():
148
  gr.Markdown("### In-Domain Taxonomy")
149
  with gr.Row():
 
153
  "examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/80645_39.76079_-74.10316.jpg",
154
  "examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/cc1ebaf9-899d-49f2-81c8-d452249a8087.jpg",
155
  "Animalia Chordata Aves Charadriiformes Laridae Larus marinus",
 
156
  ],
157
  [
158
  "examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/28871_-12.80255_-69.29999.jpg",
159
  "examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/1b8064f8-7deb-4b30-98cd-69da98ba6a3d.jpg",
160
  "Animalia Chordata Mammalia Rodentia Caviidae Hydrochoerus hydrochaeris",
 
161
  ],
162
  [
163
  "examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/277303_38.72364_-75.07749.jpg",
164
  "examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/0b9cc264-a2ba-44bd-8e41-0d01a6edd1e8.jpg",
165
  "Animalia Arthropoda Malacostraca Decapoda Ocypodidae Ocypode quadrata",
 
166
  ],
167
  [
168
  "examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/388246_45.49036_7.14796.jpg",
169
  "examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/327e1f07-692b-4140-8a3e-bd098bc064ff.jpg",
170
  "Animalia Chordata Mammalia Rodentia Sciuridae Marmota marmota",
 
171
  ],
172
  [
173
  "examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/410613_5.35573_100.28948.jpg",
174
  "examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/461d8e6c-0e66-4acc-8ecd-bfd9c218bc14.jpg",
175
  "Animalia Chordata Reptilia Squamata Varanidae Varanus salvator",
 
176
  ],
177
  ],
178
+ inputs=[sat_input, ground_input, taxonomy_input],
179
+ outputs=[metrics_out],
180
+ fn=lambda sat, grd, tax: process(sat, grd),
181
  cache_examples=False,
182
  )
183
 
 
184
  with gr.Row():
185
  gr.Markdown("### Out-Domain Taxonomy")
186
  with gr.Row():
 
190
  "examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/27423_35.64005_-121.17595.jpg",
191
  "examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/3aac526d-c921-452a-af6a-cb4f2f52e2c4.jpg",
192
  "Animalia Chordata Mammalia Carnivora Phocidae Mirounga angustirostris",
 
193
  ],
194
  [
195
  "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/1528408_13.00422_80.23033.jpg",
196
  "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/37faabd2-a613-4461-b27e-82fe5955ecaf.jpg",
197
  "Animalia Chordata Mammalia Carnivora Canidae Canis aureus",
 
198
  ],
199
  [
200
  "examples/Animalia_Chordata_Mammalia_Carnivora_Ursidae_Ursus_americanus/yosemite_v3_resized.png",
201
  "examples/Animalia_Chordata_Mammalia_Carnivora_Ursidae_Ursus_americanus/248820933.jpeg",
202
  "Animalia Chordata Mammalia Carnivora Ursidae Ursus americanus",
 
203
  ],
204
  [
205
  "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/304160_34.0144_-119.54417.jpg",
206
  "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/0cbdfbf2-6cfe-4d61-9602-c949f24d0293.jpg",
207
  "Animalia Chordata Mammalia Carnivora Canidae Urocyon littoralis",
 
208
  ],
209
  ],
210
+ inputs=[sat_input, ground_input, taxonomy_input],
211
+ outputs=[metrics_out],
212
+ fn=lambda sat, grd, tax: process(sat, grd),
213
  cache_examples=False,
214
  )
215
 
216
+
217
  run_btn.click(
218
  fn=process,
219
+ inputs=[sat_input, ground_input, taxonomy_input],
220
+ outputs=metrics_out,
221
  )
222
 
223
+ # ────────────────────────── unchanged worker initialisation ───────────
224
+ # NOTE: **Do NOT modify the code below.** It is copied verbatim from the
225
+ # user-provided snippet so that the exact same objects are created.
226
+ # The variables referenced here come from `test_parameter` which we
227
+ # imported with a wildcard earlier.
 
228
 
229
+ # Main entry point
230
  if __name__ == "__main__":
231
+
232
+ # Finally launch the Gradio interface (queue for concurrency).
233
  demo.queue(max_size=15)
234
  demo.launch(share=True)
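For reference, a minimal sketch of exercising the simplified `process` callback above without the Gradio UI. It assumes app.py's module-level objects (`planner`, `process`) are already loaded (e.g. in an interactive session) and reuses one of the bundled example satellite images purely for illustration.

from PIL import Image

# Illustrative input: one of the example satellite images listed above.
sat = Image.open(
    "examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/"
    "388246_45.49036_7.14796.jpg"
)

# Same call the Run button triggers: runs one RL search episode via
# planner.run_episode(...) and returns planner.perf_metrics for inspection.
metrics = process(sat_img=sat, ground_img=None, taxonomy=None)
print(metrics)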
app_BACKUP.py ADDED
@@ -0,0 +1,375 @@
1
+ """
2
+ Search-TTA demo
3
+ """
4
+
5
+ # ────────────────────────── imports ───────────────────────────────────
6
+ import cv2
7
+ import gradio as gr
8
+ import torch
9
+ import numpy as np
10
+ from PIL import Image
11
+ import matplotlib.pyplot as plt
12
+ import io
13
+ import torchaudio
14
+ import spaces # integration with ZeroGPU on hf
15
+
16
+ from torchvision import transforms
17
+ import open_clip
18
+ from clip_vision_per_patch_model import CLIPVisionPerPatchModel
19
+ from transformers import ClapAudioModelWithProjection
20
+ from transformers import ClapProcessor
21
+
22
+ # ────────────────────────── global config & models ────────────────────
23
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
+
25
+ # BioCLIP (ground-image & text encoder)
26
+ bio_model, _, _ = open_clip.create_model_and_transforms("hf-hub:imageomics/bioclip")
27
+ bio_model = bio_model.to(device).eval()
28
+ bio_tokenizer = open_clip.get_tokenizer("hf-hub:imageomics/bioclip")
29
+
30
+ # Satellite patch encoder (CLIP-L-336, per-patch)
31
+ sat_model: CLIPVisionPerPatchModel = (
32
+ CLIPVisionPerPatchModel.from_pretrained("derektan95/search-tta-sat")
33
+ .to(device)
34
+ .eval()
35
+ )
36
+
37
+ # Sound CLAP model
38
+ sound_model: ClapAudioModelWithProjection = (
39
+ ClapAudioModelWithProjection.from_pretrained("derektan95/search-tta-sound")
40
+ .to(device)
41
+ .eval()
42
+ )
43
+ sound_processor: ClapProcessor = ClapProcessor.from_pretrained("derektan95/search-tta-sound")
44
+ SAMPLE_RATE = 48000
45
+
46
+ logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
47
+ logit_scale = logit_scale.exp()
48
+ blur_kernel = (5,5)
49
+
50
+ # ────────────────────────── transforms (exact spec) ───────────────────
51
+ img_transform = transforms.Compose(
52
+ [
53
+ transforms.Resize((256, 256)),
54
+ transforms.CenterCrop((224, 224)),
55
+ transforms.ToTensor(),
56
+ transforms.Normalize(
57
+ mean=[0.485, 0.456, 0.406],
58
+ std=[0.229, 0.224, 0.225],
59
+ ),
60
+ ]
61
+ )
62
+
63
+ imo_transform = transforms.Compose(
64
+ [
65
+ transforms.Resize((336, 336)),
66
+ transforms.ToTensor(),
67
+ transforms.Normalize(
68
+ mean=[0.485, 0.456, 0.406],
69
+ std=[0.229, 0.224, 0.225],
70
+ ),
71
+ ]
72
+ )
73
+
74
+ def get_audio_clap(path_to_audio,format="mp3",padding="repeatpad",truncation="fusion"):
75
+ track, sr = torchaudio.load(path_to_audio, format=format) # torchaudio.load(path_to_audio)
76
+ track = track.mean(axis=0)
77
+ track = torchaudio.functional.resample(track, orig_freq=sr, new_freq=SAMPLE_RATE)
78
+ output = sound_processor(audios=track, sampling_rate=SAMPLE_RATE, max_length_s=10, return_tensors="pt",padding=padding,truncation=truncation)
79
+ return output
80
+
81
+ # ────────────────────────── helpers ───────────────────────────────────
82
+
83
+ @torch.no_grad()
84
+ def _encode_ground(img_pil: Image.Image) -> torch.Tensor:
85
+ img = img_transform(img_pil).unsqueeze(0).to(device)
86
+ img_embeds, *_ = bio_model(img)
87
+ return img_embeds
88
+
89
+
90
+ @torch.no_grad()
91
+ def _encode_text(text: str) -> torch.Tensor:
92
+ toks = bio_tokenizer(text).to(device)
93
+ _, txt_embeds, _ = bio_model(text=toks)
94
+ return txt_embeds
95
+
96
+
97
+ @torch.no_grad()
98
+ def _encode_sat(img_pil: Image.Image) -> torch.Tensor:
99
+ imo = imo_transform(img_pil).unsqueeze(0).to(device)
100
+ imo_embeds = sat_model(imo)
101
+ return imo_embeds
102
+
103
+
104
+ @torch.no_grad()
105
+ def _encode_sound(sound) -> torch.Tensor:
106
+ processed_sound = get_audio_clap(sound)
107
+ for k in processed_sound.keys():
108
+ processed_sound[k] = processed_sound[k].to(device)
109
+ unnormalized_audio_embeds = sound_model(**processed_sound).audio_embeds
110
+ sound_embeds = torch.nn.functional.normalize(unnormalized_audio_embeds, dim=-1)
111
+ return sound_embeds
112
+
113
+
114
+ def _similarity_heatmap(query: torch.Tensor, patches: torch.Tensor) -> np.ndarray:
115
+ sims = torch.matmul(query, patches.t()) * logit_scale
116
+ sims = sims.t().sigmoid()
117
+ sims = sims[1:].squeeze() # drop CLS token
118
+ side = int(np.sqrt(len(sims)))
119
+ sims = sims.reshape(side, side)
120
+ return sims.cpu().detach().numpy()
121
+
122
+
123
+ def _array_to_pil(arr: np.ndarray) -> Image.Image:
124
+ """
125
+ Render arr with viridis, automatically stretching its own min→max to 0→1
126
+ so that the most-similar patches appear yellow.
127
+ """
128
+
129
+ # Gaussian smoothing
130
+ if blur_kernel != (0,0):
131
+ arr = cv2.GaussianBlur(arr, blur_kernel, 0)
132
+
133
+ # --- contrast-stretch to local 0-1 range --------------------------
134
+ arr_min, arr_max = float(arr.min()), float(arr.max())
135
+ if arr_max - arr_min < 1e-6: # avoid /0 when the heat-map is flat
136
+ arr_scaled = np.zeros_like(arr)
137
+ else:
138
+ arr_scaled = (arr - arr_min) / (arr_max - arr_min)
139
+ # ------------------------------------------------------------------
140
+ fig, ax = plt.subplots(figsize=(2.6, 2.6), dpi=96)
141
+ ax.imshow(arr_scaled, cmap="viridis", vmin=0.0, vmax=1.0)
142
+ ax.axis("off")
143
+ buf = io.BytesIO()
144
+ plt.tight_layout(pad=0)
145
+ fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
146
+ plt.close(fig)
147
+ buf.seek(0)
148
+ return Image.open(buf)
149
+
150
+ # ────────────────────────── main inference ────────────────────────────
151
+ # integration with ZeroGPU on hf
152
+ @spaces.GPU
153
+ def process(
154
+ sat_img: Image.Image,
155
+ taxonomy: str,
156
+ ground_img: Image.Image | None,
157
+ sound: torch.Tensor | None,
158
+ ):
159
+ if sat_img is None:
160
+ return None, None
161
+
162
+ patches = _encode_sat(sat_img)
163
+
164
+ heat_ground, heat_text, heat_sound = None, None, None
165
+
166
+ if ground_img is not None:
167
+ q_img = _encode_ground(ground_img)
168
+ heat_ground = _array_to_pil(_similarity_heatmap(q_img, patches))
169
+
170
+ if taxonomy.strip():
171
+ q_txt = _encode_text(taxonomy.strip())
172
+ heat_text = _array_to_pil(_similarity_heatmap(q_txt, patches))
173
+
174
+ if sound is not None:
175
+ q_sound = _encode_sound(sound)
176
+ heat_sound = _array_to_pil(_similarity_heatmap(q_sound, patches))
177
+
178
+ return heat_ground, heat_text, heat_sound
179
+
180
+
181
+ # ────────────────────────── Gradio UI ─────────────────────────────────
182
+ with gr.Blocks(title="Search-TTA", theme=gr.themes.Base()) as demo:
183
+
184
+ with gr.Row():
185
+ gr.Markdown(
186
+ """
187
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
188
+ <div>
189
+ <h1>Search-TTA: A Multimodal Test-Time Adaptation Framework for Visual Search in the Wild</h1>
190
+ <span></span>
191
+ <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
192
+ <a href="https://search-tta.github.io">Project Website</a>
193
+ </h2>
194
+ <span></span>
195
+ <h2 style='font-weight: 450; font-size: 0.5rem; margin: 0rem'>[Work in Progress]</h2>
196
+ </div>
197
+ </div>
198
+ """
199
+ # <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>WACV 2025</h2>
200
+
201
+ # <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
202
+ # <a href="https://derektan95.github.io">Derek M. S. Tan</a>,
203
+ # <a href="https://chinchinati.github.io/">Shailesh</a>,
204
+ # <a href="https://www.linkedin.com/in/boyang-liu-nus">Boyang Liu</a>,
205
+ # <a href="https://www.linkedin.com/in/loki-silvres">Alok Raj</a>,
206
+ # <a href="https://www.linkedin.com/in/ang-qi-xuan-714347142">Qi Xuan Ang</a>,
207
+ # <a href="https://weihengdai.top">Weiheng Dai</a>,
208
+ # <a href="https://www.linkedin.com/in/tanishqduhan">Tanishq Duhan</a>,
209
+ # <a href="https://www.linkedin.com/in/jimmychiun">Jimmy Chiun</a>,
210
+ # <a href="https://www.yuhongcao.online/">Yuhong Cao</a>,
211
+ # <a href="https://www.cs.toronto.edu/~florian/">Florian Shkurti</a>,
212
+ # <a href="https://www.marmotlab.org/bio.html">Guillaume Sartoretti</a>
213
+ # </h2>
214
+ # <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>National University of Singapore, University of Toronto, IIT-Dhanbad, Singapore Technologies Engineering</h2>
215
+ )
216
+
217
+ with gr.Row(variant="panel"):
218
+
219
+ # LEFT COLUMN (satellite, taxonomy, run)
220
+ with gr.Column():
221
+ sat_input = gr.Image(
222
+ label="Satellite Image",
223
+ sources=["upload"],
224
+ type="pil",
225
+ height=320,
226
+ )
227
+ taxonomy_input = gr.Textbox(
228
+ label="Full Taxonomy Name (optional)",
229
+ placeholder="e.g. Animalia Chordata Mammalia Rodentia Sciuridae Marmota marmota",
230
+ )
231
+
232
+ # ─── NEW: sound input ───────────────────────────
233
+ sound_input = gr.Audio(
234
+ label="Sound Input (optional)",
235
+ sources=["upload"], # or "microphone" / "url" as you prefer
236
+ type="filepath", # or "numpy" if you want raw arrays
237
+ )
238
+ run_btn = gr.Button("Run", variant="primary")
239
+
240
+ # RIGHT COLUMN (ground image + two heat-maps)
241
+ with gr.Column():
242
+ ground_input = gr.Image(
243
+ label="Ground-level Image (optional)",
244
+ sources=["upload"],
245
+ type="pil",
246
+ height=320,
247
+ )
248
+ gr.Markdown("### Heat-map Results")
249
+ with gr.Row():
250
+ # Separate label and image to avoid overlap
251
+ with gr.Column(scale=1, min_width=100):
252
+ gr.Markdown("**Ground Image Query**", elem_id="label-ground")
253
+ heat_ground_out = gr.Image(
254
+ show_label=False,
255
+ height=160,
256
+ # width=160,
257
+ )
258
+ with gr.Column(scale=1, min_width=100):
259
+ gr.Markdown("**Text Query**", elem_id="label-text")
260
+ heat_text_out = gr.Image(
261
+ show_label=False,
262
+ height=160,
263
+ # width=160,
264
+ )
265
+ with gr.Column(scale=1, min_width=100):
266
+ gr.Markdown("**Sound Query**", elem_id="label-sound")
267
+ heat_sound_out = gr.Image(
268
+ show_label=False,
269
+ height=160,
270
+ # width=160,
271
+ )
272
+ # ─── NEW: sound output ─────────────────────────
273
+ # sound_output = gr.Audio(
274
+ # label="Playback",
275
+ # )
276
+
277
+
278
+ # EXAMPLES
279
+ with gr.Row():
280
+ gr.Markdown("### In-Domain Taxonomy")
281
+ with gr.Row():
282
+ gr.Examples(
283
+ examples=[
284
+ [
285
+ "examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/80645_39.76079_-74.10316.jpg",
286
+ "examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/cc1ebaf9-899d-49f2-81c8-d452249a8087.jpg",
287
+ "Animalia Chordata Aves Charadriiformes Laridae Larus marinus",
288
+ "examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/89758229.mp3"
289
+ ],
290
+ [
291
+ "examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/28871_-12.80255_-69.29999.jpg",
292
+ "examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/1b8064f8-7deb-4b30-98cd-69da98ba6a3d.jpg",
293
+ "Animalia Chordata Mammalia Rodentia Caviidae Hydrochoerus hydrochaeris",
294
+ "examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/166631961.mp3"
295
+ ],
296
+ [
297
+ "examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/277303_38.72364_-75.07749.jpg",
298
+ "examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/0b9cc264-a2ba-44bd-8e41-0d01a6edd1e8.jpg",
299
+ "Animalia Arthropoda Malacostraca Decapoda Ocypodidae Ocypode quadrata",
300
+ "examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/12372063.mp3"
301
+ ],
302
+ [
303
+ "examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/388246_45.49036_7.14796.jpg",
304
+ "examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/327e1f07-692b-4140-8a3e-bd098bc064ff.jpg",
305
+ "Animalia Chordata Mammalia Rodentia Sciuridae Marmota marmota",
306
+ "examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/59677071.mp3"
307
+ ],
308
+ [
309
+ "examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/410613_5.35573_100.28948.jpg",
310
+ "examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/461d8e6c-0e66-4acc-8ecd-bfd9c218bc14.jpg",
311
+ "Animalia Chordata Reptilia Squamata Varanidae Varanus salvator",
312
+ None
313
+ ],
314
+ ],
315
+ inputs=[sat_input, ground_input, taxonomy_input, sound_input],
316
+ outputs=[heat_ground_out, heat_text_out, heat_sound_out],
317
+ fn=process,
318
+ cache_examples=False,
319
+ )
320
+
321
+ # EXAMPLES
322
+ with gr.Row():
323
+ gr.Markdown("### Out-Domain Taxonomy")
324
+ with gr.Row():
325
+ gr.Examples(
326
+ examples=[
327
+ [
328
+ "examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/27423_35.64005_-121.17595.jpg",
329
+ "examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/3aac526d-c921-452a-af6a-cb4f2f52e2c4.jpg",
330
+ "Animalia Chordata Mammalia Carnivora Phocidae Mirounga angustirostris",
331
+ "examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/3123948.mp3"
332
+ ],
333
+ [
334
+ "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/1528408_13.00422_80.23033.jpg",
335
+ "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/37faabd2-a613-4461-b27e-82fe5955ecaf.jpg",
336
+ "Animalia Chordata Mammalia Carnivora Canidae Canis aureus",
337
+ "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/189318716.mp3"
338
+ ],
339
+ [
340
+ "examples/Animalia_Chordata_Mammalia_Carnivora_Ursidae_Ursus_americanus/yosemite_v3_resized.png",
341
+ "examples/Animalia_Chordata_Mammalia_Carnivora_Ursidae_Ursus_americanus/248820933.jpeg",
342
+ "Animalia Chordata Mammalia Carnivora Ursidae Ursus americanus",
343
+ None
344
+ ],
345
+ [
346
+ "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/304160_34.0144_-119.54417.jpg",
347
+ "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/0cbdfbf2-6cfe-4d61-9602-c949f24d0293.jpg",
348
+ "Animalia Chordata Mammalia Carnivora Canidae Urocyon littoralis",
349
+ None
350
+ ],
351
+ ],
352
+ inputs=[sat_input, ground_input, taxonomy_input, sound_input],
353
+ outputs=[heat_ground_out, heat_text_out, heat_sound_out],
354
+ fn=process,
355
+ cache_examples=False,
356
+ )
357
+
358
+ # CALLBACK
359
+ run_btn.click(
360
+ fn=process,
361
+ inputs=[sat_input, taxonomy_input, ground_input, sound_input],
362
+ outputs=[heat_ground_out, heat_text_out, heat_sound_out],
363
+ )
364
+
365
+ # Footer to point out to model and data from app page.
366
+ gr.Markdown(
367
+ """
368
+ The satellite image CLIP encoder is fine-tuned using [Sentinel-2 Level 2A](https://docs.sentinel-hub.com/api/latest/data/sentinel-2-l2a/) satellite image and taxonomy images (with GPS locations) from [iNaturalist](https://inaturalist.org/). The sound CLIP encoder is fine-tuned with a subset of the same taxonomy images and their corresponding sounds from [iNaturalist](https://inaturalist.org/). Note that while some of the examples above result in poor probability distributions, they will be improved using our test-time adaptation framework during the search process.
369
+ """
370
+ )
371
+
372
+ # LAUNCH
373
+ if __name__ == "__main__":
374
+ demo.queue(max_size=15)
375
+ demo.launch(share=True)
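As a sanity check, the patch-similarity logic in `_similarity_heatmap` above can be exercised in isolation; a minimal sketch with random stand-in embeddings (the 577 = 1 CLS + 24x24 patch tokens and the 512-dim width are illustrative assumptions, not the encoders' exact output shapes):

import numpy as np
import torch

torch.manual_seed(0)
query = torch.randn(1, 512)      # stand-in query embedding (ground image, text or sound)
patches = torch.randn(577, 512)  # stand-in per-patch satellite embeddings: 1 CLS + 24*24 patches
logit_scale = (torch.ones([]) * np.log(1 / 0.07)).exp()

sims = torch.matmul(query, patches.t()) * logit_scale  # (1, 577) similarity logits
sims = sims.t().sigmoid()[1:].squeeze()                # drop the CLS token -> (576,)
side = int(np.sqrt(len(sims)))                         # 24
heatmap = sims.reshape(side, side).cpu().numpy()       # same kind of array fed to _array_to_pil
print(heatmap.shape)                                   # (24, 24)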
env.py ADDED
@@ -0,0 +1,811 @@
1
+ #######################################################################
2
+ # Name: env.py
3
+ #
4
+ # - Reads and processes training and test maps (E.g. DungeonMaps)
5
+ # - Processes rewards, new frontiers given action
6
+ # - Updates a graph representation of environment for input into network
7
+ #######################################################################
8
+
9
+ import sys
10
+
11
+ import cv2
12
+ from matplotlib.colors import LogNorm, PowerNorm
13
+ if sys.modules['TRAINING']:
14
+ from parameter import *
15
+ else:
16
+ from test_parameter import *
17
+
18
+ import copy
19
+ import pandas as pd
20
+ import rasterio
21
+ from skimage import io
22
+ import matplotlib.pyplot as plt
23
+ import os
24
+ from skimage.measure import block_reduce
25
+ from sensor import *
26
+ from graph_generator import *
27
+ from node import *
28
+ from scipy.ndimage import label, find_objects
29
+ import matplotlib.image as mpimg
30
+ from matplotlib.colors import Normalize
31
+
32
+ # import matplotlib
33
+ # matplotlib.use("Agg") # <-- key line to avoid tkinter dependency
34
+
35
+
36
+ class Env():
37
+ def __init__(self, map_index, n_agent, k_size=20, plot=False, test=False, mask_index=None):
38
+ self.n_agent = n_agent
39
+ self.test = test
40
+ self.map_dir = GRIDMAP_SET_DIR
41
+
42
+ # Import environment gridmap
43
+ self.map_list = os.listdir(self.map_dir)
44
+ self.map_list.sort(reverse=True)
45
+
46
+ # NEW: Import segmentation utility map
47
+ self.seg_dir = MASK_SET_DIR
48
+ self.segmentation_mask, self.target_positions, self.target_found_idxs = None, [], []
49
+ self.segmentation_mask_list = os.listdir(self.seg_dir)
50
+ self.segmentation_mask_list.sort(reverse=True)
51
+
52
+ # Import target maps (if relevant)
53
+ if TARGETS_SET_DIR != "":
54
+ self.targets_map_list = os.listdir(TARGETS_SET_DIR)
55
+ self.targets_map_list.sort(reverse=True)
56
+
57
+ # # NEW: Find common files in both directories
58
+ # if TARGETS_SET_DIR != "":
59
+ # common_files = [file for file in self.map_list if file in self.segmentation_mask_list and file in self.targets_map_list]
60
+ # else:
61
+ # common_files = [file for file in self.map_list if file in self.segmentation_mask_list]
62
+ self.map_index = map_index % len(self.map_list)
63
+ if mask_index is not None:
64
+ self.mask_index = mask_index % len(self.segmentation_mask_list)
65
+ else:
66
+ self.mask_index = map_index % len(self.segmentation_mask_list)
67
+ # self.common_map_file = common_files[self.map_index]
68
+ # print("self.common_map_file: ", self.common_map_file)
69
+
70
+ # Import ground truth and segmentation mask
71
+ self.ground_truth, self.map_start_position = self.import_ground_truth(
72
+ os.path.join(self.map_dir, self.map_list[self.map_index]))# self.common_map_file))
73
+ self.ground_truth_size = np.shape(self.ground_truth) # (480, 640)
74
+ self.robot_belief = np.ones(self.ground_truth_size) * 127 # unexplored 127
75
+ self.downsampled_belief = None
76
+ self.old_robot_belief = copy.deepcopy(self.robot_belief)
77
+ self.coverage_belief = np.ones(self.ground_truth_size) * 127 # unexplored 127
78
+
79
+ # Import segmentation mask
80
+ mask_filename = self.segmentation_mask_list[self.mask_index]
81
+ self.segmentation_mask = self.import_segmentation_mask(
82
+ os.path.join(self.seg_dir, mask_filename))# self.common_map_file))
83
+ # print("mask_filename: ", mask_filename)
84
+
85
+ # Overwrite target positions if directory specified
86
+ self.gt_segmentation_mask = None
87
+ if self.test and TARGETS_SET_DIR != "":
88
+ self.gt_segmentation_mask = self.import_segmentation_mask(
89
+ os.path.join(TARGETS_SET_DIR, self.map_list[self.map_index])) # UNUSED - self.common_map_file))
90
+ # print("target_positions: ", self.target_positions)
91
+ # print("np.unique(self.segmentation_mask): ", np.unique(self.segmentation_mask))
92
+
93
+ self.segmentation_info_mask = None
94
+ self.gt_segmentation_info_mask = None
95
+ self.segmentation_info_mask_unnormalized = None
96
+ self.filtered_seg_info_mask = None
97
+ self.num_targets_found = 0
98
+ self.num_new_targets_found = 0
99
+
100
+ # # Link score masks to raw image files
101
+ # csv_file = pd.read_csv(RAW_IMG_PATH_DICT, header=None)
102
+ # img_score_paths = csv_file.iloc[:,2].tolist()
103
+ # raw_img_paths = csv_file.iloc[:,0].tolist()
104
+ # self.score_to_img_dict = {os.path.basename(img_score_path): raw_img_path for img_score_path, raw_img_path in zip(img_score_paths, raw_img_paths)}
105
+
106
+ self.resolution = 4
107
+ self.sensor_range = SENSOR_RANGE
108
+ self.explored_rate = 0
109
+ self.targets_found_rate = 0
110
+ self.info_gain = 0
111
+ self.total_info = 0
112
+
113
+ self.graph_generator = Graph_generator(map_size=self.ground_truth_size, sensor_range=self.sensor_range, k_size=k_size, plot=plot)
114
+ self.node_coords, self.graph, self.node_utility, self.guidepost = None, None, None, None
115
+
116
+ self.frontiers = None
117
+
118
+ self.start_positions = []
119
+ self.begin(self.map_start_position)
120
+
121
+ self.plot = plot
122
+ self.frame_files = []
123
+
124
+ def find_index_from_coords(self, position):
125
+ index = np.argmin(np.linalg.norm(self.node_coords - position, axis=1))
126
+ return index
127
+
128
+ def begin(self, start_position):
129
+ # self.robot_belief = self.update_robot_belief(robot_position, self.sensor_range, self.robot_belief,
130
+ # self.ground_truth)
131
+ self.robot_belief = self.ground_truth
132
+
133
+ self.downsampled_belief = block_reduce(self.robot_belief.copy(), block_size=(self.resolution, self.resolution),
134
+ func=np.min)
135
+ self.frontiers = self.find_frontier()
136
+ self.old_robot_belief = copy.deepcopy(self.robot_belief)
137
+
138
+ self.node_coords, self.graph, self.node_utility, self.guidepost = self.graph_generator.generate_graph(
139
+ self.robot_belief, self.frontiers)
140
+
141
+ # Find non-conflicting start positions
142
+ if FIX_START_POSITION:
143
+ coords_res_row = int(self.robot_belief.shape[0]/NUM_COORDS_HEIGHT)
144
+ coords_res_col = int(self.robot_belief.shape[1]/NUM_COORDS_WIDTH)
145
+ self.start_positions = [(int(self.robot_belief.shape[1]/2)-coords_res_col/2,int(self.robot_belief.shape[0]/2)-coords_res_row/2) for _ in range(self.n_agent)] # bottom-left corner
146
+ else:
147
+ nearby_coords = self.graph_generator.get_neighbors_grid_coords(start_position)
148
+ itr = 0
149
+ for i in range(self.n_agent):
150
+ if i == 0 or len(nearby_coords) == 0:
151
+ self.start_positions.append(start_position)
152
+ else:
153
+ idx = min(itr, len(nearby_coords)-1)
154
+ self.start_positions.append(nearby_coords[idx])
155
+ itr += 1
156
+
157
+ for i in range(len(self.start_positions)):
158
+ self.start_positions[i] = self.node_coords[self.find_index_from_coords(self.start_positions[i])]
159
+ self.coverage_belief = self.update_robot_belief(self.start_positions[i], self.sensor_range, self.coverage_belief,
160
+ self.ground_truth)
161
+
162
+ for start_position in self.start_positions:
163
+ self.graph_generator.route_node.append(start_position)
164
+
165
+ # Info map from ground truth
166
+ rng_x = 0.5 * (self.ground_truth.shape[1] / NUM_COORDS_WIDTH)
167
+ rng_y = 0.5 * (self.ground_truth.shape[0] / NUM_COORDS_HEIGHT)
168
+ self.segmentation_info_mask = np.zeros((len(self.node_coords), 1))
169
+ self.gt_segmentation_info_mask = np.zeros((len(self.node_coords), 1))
170
+ for i, node_coord in enumerate(self.node_coords):
171
+ max_x = min(node_coord[0] + int(math.ceil(rng_x)), self.ground_truth.shape[1])
172
+ min_x = max(node_coord[0] - int(math.ceil(rng_x)), 0)
173
+ max_y = min(node_coord[1] + int(math.ceil(rng_y)), self.ground_truth.shape[0])
174
+ min_y = max(node_coord[1] - int(math.ceil(rng_y)), 0)
175
+
176
+ # if np.any(self.segmentation_mask[min_y:max_y, min_x:max_x] == 255):
177
+ # self.segmentation_info_mask[i] = 1.0
178
+ # else:
179
+ # self.segmentation_info_mask[i] = 0.0
180
+ # self.segmentation_info_mask[i] = np.mean(self.segmentation_mask[min_y:max_y, min_x:max_x])
181
+ # self.segmentation_info_mask[i] = np.max(self.segmentation_mask[min_y:max_y, min_x:max_x])
182
+ if TARGETS_SET_DIR == "": # If targets combined with segmentation mask
183
+ exclude = {208} # Exclude target positions
184
+ else:
185
+ exclude = {}
186
+ self.segmentation_info_mask[i] = max(x for x in self.segmentation_mask[min_y:max_y, min_x:max_x].flatten() if x not in exclude) / 100.0
187
+ if self.gt_segmentation_mask is not None:
188
+ self.gt_segmentation_info_mask[i] = max(x for x in self.gt_segmentation_mask[min_y:max_y, min_x:max_x].flatten() if x not in exclude) / 100.0
189
+ # print("np.unique(self.segmentation_info_mask): ", np.unique(self.segmentation_info_mask))
190
+ self.filtered_seg_info_mask = copy.deepcopy(self.segmentation_info_mask)
191
+
192
+ # In case targets found at beginning...
193
+ done, num_targets_found = self.check_done()
194
+ self.num_targets_found = num_targets_found
195
+
196
+
197
+ def multi_robot_step(self, next_position_list, dist_list, travel_dist_list):
198
+ temp_frontiers = copy.deepcopy(self.frontiers)
199
+ reward_list = []
200
+ for dist, robot_position in zip(dist_list, next_position_list):
201
+ self.graph_generator.route_node.append(robot_position)
202
+ next_node_index = self.find_index_from_coords(robot_position)
203
+ self.graph_generator.nodes_list[next_node_index].set_visited()
204
+ # self.robot_belief = self.update_robot_belief(robot_position, self.sensor_range, self.robot_belief,
205
+ # self.ground_truth)
206
+ self.coverage_belief = self.update_robot_belief(robot_position, self.sensor_range, self.coverage_belief,
207
+ self.ground_truth)
208
+ self.robot_belief = self.ground_truth
209
+
210
+ self.downsampled_belief = block_reduce(self.robot_belief.copy(),
211
+ block_size=(self.resolution, self.resolution),
212
+ func=np.min)
213
+
214
+ frontiers = self.find_frontier()
215
+ #num_observed_frontiers = self.calculate_num_observed_frontiers(temp_frontiers, frontiers)
216
+ #temp_frontiers = frontiers
217
+
218
+ num_observed_frontiers = self.node_utility[next_node_index]
219
+
220
+ # individual_reward = num_observed_frontiers / 50 - dist / 64
221
+ individual_reward = -dist / 32 # 64
222
+
223
+ info_gain_reward = 0
224
+ robot_position_idx = self.find_index_from_coords(robot_position)
225
+ # if self.segmentation_info_mask[robot_position_idx] == 1.0 and self.guidepost[robot_position_idx] == 0.0:
226
+ # # print("High Info (Unvisited)")
227
+ # info_gain_reward = (HIGH_INFO_REWARD_RATIO * 1.5)
228
+ # elif self.segmentation_info_mask[robot_position_idx] == 0.0 and self.guidepost[robot_position_idx] == 0.0:
229
+ # # print("Low Info (Unvisited)")
230
+ # info_gain_reward = ((1-HIGH_INFO_REWARD_RATIO) * 1.5)
231
+ info_gain_reward = self.filtered_seg_info_mask[robot_position_idx][0] * 1.5
232
+ if self.guidepost[robot_position_idx] == 0.0:
233
+ info_gain_reward += 0.2
234
+ # print("info_gain_reward: ", info_gain_reward)
235
+ individual_reward += info_gain_reward
236
+
237
+ # print("dist / 64: ", dist / 64)
238
+ # print("info gain reward: ", info_gain_reward)
239
+
240
+ reward_list.append(individual_reward)
241
+
242
+ self.node_coords, self.graph, self.node_utility, self.guidepost = self.graph_generator.update_graph(self.robot_belief, self.old_robot_belief, frontiers, self.frontiers)
243
+ self.old_robot_belief = copy.deepcopy(self.robot_belief)
244
+
245
+ self.filtered_seg_info_mask = [info[0] if self.guidepost[i] == 0.0 else 0.0 for i, info in enumerate(self.segmentation_info_mask)]
246
+ self.filtered_seg_info_mask = np.expand_dims(np.array(self.filtered_seg_info_mask), axis=1)
247
+
248
+ num_observed_frontiers = self.calculate_num_observed_frontiers(self.frontiers, frontiers)
249
+ self.frontiers = frontiers
250
+ self.explored_rate = self.evaluate_exploration_rate()
251
+
252
+ done, num_targets_found = self.check_done()
253
+ self.num_new_targets_found = num_targets_found - self.num_targets_found
254
+ # #team_reward = sum(reward_list) / len(reward_list)
255
+ # # team_reward = num_observed_frontiers / 50
256
+ # team_reward = self.num_new_targets_found * 5.0
257
+ team_reward = 0.0
258
+ # # print("target found reward: ", self.num_new_targets_found * 5.0)
259
+
260
+ self.num_targets_found = num_targets_found
261
+ self.targets_found_rate = self.evaluate_targets_found_rate()
262
+ self.info_gain, self.total_info = self.evaluate_info_gain()
263
+
264
+ if done:
265
+ # team_reward += np.sum(self.robot_belief == 255) / sum(travel_dist_list)
266
+ team_reward += 40 # 20
267
+ for i in range(len(reward_list)):
268
+ reward_list[i] += team_reward
269
+
270
+ return reward_list, done
271
+
272
+
273
+ def import_ground_truth(self, map_index):
274
+ # occupied 1, free 255, unexplored 127
275
+
276
+ try:
277
+ # ground_truth = (io.imread(map_index, 1) * 255).astype(int)
278
+ ground_truth = (io.imread(map_index, 1)).astype(int)
279
+ if np.all(ground_truth == 0):
280
+ ground_truth = (io.imread(map_index, 1) * 255).astype(int)
281
+ except:
282
+ new_map_index = self.map_dir + '/' + self.map_list[0]
283
+ ground_truth = (io.imread(new_map_index, 1)).astype(int)
284
+ print('could not read the map_path ({}), hence skipping it and using ({}).'.format(map_index, new_map_index))
285
+
286
+ robot_location = np.nonzero(ground_truth == 208)
287
+
288
+ # print("robot_location: ", robot_location)
289
+ # print("np.array(robot_location)[1, 127]: ", np.array(robot_location)[1, 127])
290
+
291
+ robot_location = np.array([np.array(robot_location)[1, 127], np.array(robot_location)[0, 127]])
292
+ ground_truth = (ground_truth > 150)
293
+ ground_truth = ground_truth * 254 + 1
294
+ return ground_truth, robot_location
295
+
296
+
297
+ def import_segmentation_mask(self, map_index):
298
+ # occupied 1, free 255, unexplored 127
299
+
300
+ # mask = (io.imread(map_index, 1) * 255).astype(int) # NOTE: Cannot work well with seg mask self-generated
301
+ mask = cv2.imread(map_index).astype(int)
302
+ # print("np.unique(segmentation_mask): ", np.unique(mask))
303
+
304
+ # NOTE: Could contain multiple start positions
305
+ # target_position = np.nonzero(mask == 208)
306
+ # target_positions = self.find_target_locations(mask)
307
+
308
+ # target_position = np.array([np.array(target_position)[1, 127], np.array(target_position)[0, 127]])
309
+ return mask #, target_positions
310
+
311
+
312
+ def find_target_locations(self, image_array, grey_value=208):
313
+ # Load the image
314
+ # image = Image.open(image_path)
315
+ # image_array = np.array(image)
316
+
317
+ # Identify pixels equal to the grey value
318
+ grey_pixels = np.where(image_array == grey_value)
319
+
320
+ # Create a binary array where grey pixels are marked as True
321
+ binary_array = np.zeros_like(image_array, dtype=bool)
322
+ binary_array[grey_pixels] = True
323
+
324
+ # Label connected components
325
+ labeled_array, num_features = label(binary_array)
326
+
327
+ # Find objects returns slices for each connected component
328
+ slices = find_objects(labeled_array)
329
+
330
+ # Calculate the center of each box
331
+ centers = []
332
+ for slice in slices:
333
+ row_center = (slice[0].start + slice[0].stop - 1) // 2
334
+ col_center = (slice[1].start + slice[1].stop - 1) // 2
335
+ centers.append((col_center, row_center)) # (y,x)
336
+
337
+ return centers
338
+
339
+ def free_cells(self):
340
+ index = np.where(self.ground_truth == 255)
341
+ free = np.asarray([index[1], index[0]]).T
342
+ return free
343
+
344
+ def update_robot_belief(self, robot_position, sensor_range, robot_belief, ground_truth):
345
+ robot_belief = sensor_work(robot_position, sensor_range, robot_belief, ground_truth)
346
+ return robot_belief
347
+
348
+
349
+ def check_done(self, robot_id=0):
350
+ """ All agnets to have explored most of the env map """
351
+ done = False
352
+ # for idx in range(self.n_agent):
353
+ # if np.sum(self.ground_truth == 255) - np.sum(self.all_robot_belief[idx][idx] == 255) > 40:
354
+ # done = False
355
+
356
+ # NEW: ADDITIONAL VLM SEARCH CRITERIA
357
+ num_targets_found = 0
358
+ self.target_found_idxs = []
359
+ for i, target in enumerate(self.target_positions):
360
+ if self.coverage_belief[target[1], target[0]] == 255: # 255:
361
+ num_targets_found += 1
362
+ self.target_found_idxs.append(i)
363
+ # free_cells_mask = self.all_robot_belief[robot_id][robot_id] == 255
364
+ # filtered_segmentation_mask = np.where(free_cells_mask, self.segmentation_mask, 0)
365
+ # targets = self.find_target_locations(filtered_segmentation_mask)
366
+ # print("num_targets_found: ", num_targets_found)
367
+
368
+ if TERMINATE_ON_TGTS_FOUND and num_targets_found >= len(self.target_positions):
369
+ done = True
370
+ if not TERMINATE_ON_TGTS_FOUND and np.sum(self.coverage_belief == 255) / np.sum(self.ground_truth == 255) >= 0.99:
371
+ done = True
372
+
373
+ return done, num_targets_found
374
+
375
+
376
+ def calculate_num_observed_frontiers(self, old_frontiers, frontiers):
377
+ frontiers_to_check = frontiers[:, 0] + frontiers[:, 1] * 1j
378
+ pre_frontiers_to_check = old_frontiers[:, 0] + old_frontiers[:, 1] * 1j
379
+ frontiers_num = np.intersect1d(frontiers_to_check, pre_frontiers_to_check).shape[0]
380
+ pre_frontiers_num = pre_frontiers_to_check.shape[0]
381
+ delta_num = pre_frontiers_num - frontiers_num
382
+
383
+ return delta_num
384
+
385
+ def calculate_reward(self, dist, frontiers):
386
+ reward = 0
387
+ reward -= dist / 64
388
+
389
+ frontiers_to_check = frontiers[:, 0] + frontiers[:, 1] * 1j
390
+ pre_frontiers_to_check = self.frontiers[:, 0] + self.frontiers[:, 1] * 1j
391
+ frontiers_num = np.intersect1d(frontiers_to_check, pre_frontiers_to_check).shape[0]
392
+ pre_frontiers_num = pre_frontiers_to_check.shape[0]
393
+ delta_num = pre_frontiers_num - frontiers_num
394
+
395
+ reward += delta_num / 50
396
+
397
+ return reward
398
+
399
+ def evaluate_exploration_rate(self):
400
+ # rate = np.sum(self.robot_belief == 255) / np.sum(self.ground_truth == 255)
401
+ rate = np.sum(self.coverage_belief == 255) / np.sum(self.ground_truth == 255)
402
+ return rate
403
+
404
+ def evaluate_targets_found_rate(self):
405
+ if len(self.target_positions) == 0:
406
+ return 0
407
+ else:
408
+ rate = self.num_targets_found / len(self.target_positions)
409
+ return rate
410
+
411
+ def evaluate_info_gain(self):
412
+ # print("self.segmentation_mask.shape: ", self.segmentation_mask.shape)
413
+ # coverage_belief = (self.coverage_belief == 255)
414
+ # print("coverage_belief.shape: ", coverage_belief.shape)
415
+ # print("np.unique(coverage_belief): ", np.unique(coverage_belief))
416
+ # print("np.count_nonzero(coverage_belief): ", np.count_nonzero(coverage_belief))
417
+ # print("np.count_zero(coverage_belief): ", coverage_belief.size - np.count_nonzero(coverage_belief))
418
+ # print("self.segmentation_mask[self.coverage_belief == 255].shape: ", self.segmentation_mask[self.coverage_belief == 255].shape)
419
+ if self.test and TARGETS_SET_DIR != "":
420
+ info_gained = np.sum(self.gt_segmentation_mask[self.coverage_belief == 255]) / 100.0
421
+ total_info = np.sum(self.gt_segmentation_mask) / 100.0
422
+ else:
423
+ info_gained = np.sum(self.segmentation_mask[self.coverage_belief == 255]) / 100.0
424
+ total_info = np.sum(self.segmentation_mask) / 100.0
425
+ return info_gained, total_info
426
+
427
+ def calculate_new_free_area(self):
428
+ old_free_area = self.old_robot_belief == 255
429
+ current_free_area = self.robot_belief == 255
430
+
431
+ new_free_area = (current_free_area.astype(np.int) - old_free_area.astype(np.int)) * 255
432
+
433
+ return new_free_area, np.sum(old_free_area)
434
+
435
+ def calculate_dist_path(self, path):
436
+ dist = 0
437
+ start = path[0]
438
+ end = path[-1]
439
+ for index in path:
440
+ if index == end:
441
+ break
442
+ dist += np.linalg.norm(self.node_coords[start] - self.node_coords[index])
443
+ start = index
444
+ return dist
445
+
446
+ def find_frontier(self):
447
+ y_len = self.downsampled_belief.shape[0]
448
+ x_len = self.downsampled_belief.shape[1]
449
+ mapping = self.downsampled_belief.copy()
450
+ belief = self.downsampled_belief.copy()
451
+ # 0-1 unknown area map
452
+ mapping = (mapping == 127) * 1
453
+ mapping = np.lib.pad(mapping, ((1, 1), (1, 1)), 'constant', constant_values=0)
454
+ fro_map = mapping[2:][:, 1:x_len + 1] + mapping[:y_len][:, 1:x_len + 1] + mapping[1:y_len + 1][:, 2:] + \
455
+ mapping[1:y_len + 1][:, :x_len] + mapping[:y_len][:, 2:] + mapping[2:][:, :x_len] + mapping[2:][:,
456
+ 2:] + \
457
+ mapping[:y_len][:, :x_len]
458
+ ind_free = np.where(belief.ravel(order='F') == 255)[0]
459
+ ind_fron_1 = np.where(1 < fro_map.ravel(order='F'))[0]
460
+ ind_fron_2 = np.where(fro_map.ravel(order='F') < 8)[0]
461
+ ind_fron = np.intersect1d(ind_fron_1, ind_fron_2)
462
+ ind_to = np.intersect1d(ind_free, ind_fron)
463
+
464
+ map_x = x_len
465
+ map_y = y_len
466
+ x = np.linspace(0, map_x - 1, map_x)
467
+ y = np.linspace(0, map_y - 1, map_y)
468
+ t1, t2 = np.meshgrid(x, y)
469
+ points = np.vstack([t1.T.ravel(), t2.T.ravel()]).T
470
+
471
+ f = points[ind_to]
472
+ f = f.astype(int)
473
+
474
+ f = f * self.resolution
475
+
476
+ return f
477
+
478
+ def plot_env(self, n, path, step, travel_dist, robots_route, img_path_override=None, sat_path_override=None, msk_name_override=None, sound_id_override=None, colormap_mid_val=None):
479
+
480
+ # # TEMP
481
+ # if TAXABIND_TTA:
482
+ # # Save self.segmentation_info_mask as .npy file in gifs_path
483
+ # side_dim = int(np.sqrt(self.segmentation_info_mask.shape[0]))
484
+ # mask_viz = self.segmentation_info_mask.squeeze().reshape((side_dim, side_dim)).T
485
+ # np.save(os.path.join(path, f"seg_mask_step{step}.npy"), mask_viz)
486
+
487
+ plt.switch_backend('agg')
488
+ # plt.ion()
489
+ plt.cla()
490
+ color_list = ["r", "g", "c", "m", "y", "k"]
491
+
492
+ if TARGETS_SET_DIR == "" and not TAXABIND_TTA:
493
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
494
+ else:
495
+ fig, ((ax1, ax2, ax3), (ax4, ax5, ax6)) = plt.subplots(2, 3, figsize=(12, 8))
496
+
497
+ ### Fig1: Environment ###
498
+ msk_name = ""
499
+ if TAXABIND_TTA:
500
+ image = mpimg.imread(sat_path_override)
501
+ msk_name = msk_name_override
502
+ # else:
503
+ # plt.imshow(self.robot_belief, cmap='gray')
504
+ # ax1.imshow(self.coverage_belief, cmap='gray')
505
+ # image = mpimg.imread("Maps/real_maps/real/4259_masked_img_0.jpg")
506
+ # msk_name = self.map_list[self.map_index]
507
+ # raw_img_path = self.score_to_img_dict[msk_name]
508
+ # if "flair" in raw_img_path:
509
+ # with rasterio.open(raw_img_path) as src_img:
510
+ # image = src_img.read([1,2,3])
511
+ # image = np.transpose(image, (1, 2, 0))
512
+ # else:
513
+ # image = mpimg.imread(raw_img_path)
514
+
515
+
516
+ ### Fig1: Environment ###
517
+ ax = ax1 # if TAXABIND_TTA else ax1
518
+ ax.imshow(image)
519
+ ax.axis((0, self.ground_truth_size[1], self.ground_truth_size[0], 0))
520
+ ax.set_title("Image")
521
+ # if VIZ_GRAPH_EDGES:
522
+ # for i in range(len(self.graph_generator.x)):
523
+ # ax.plot(self.graph_generator.x[i], self.graph_generator.y[i], 'tan', zorder=1)
524
+ # # ax.scatter(self.node_coords[:, 0], self.node_coords[:, 1], c=self.node_utility, zorder=5)
525
+ # ax.scatter(self.node_coords[:, 0], self.node_coords[:, 1], c=self.segmentation_info_mask, zorder=5)
526
+ # ax.scatter(self.frontiers[:, 0], self.frontiers[:, 1], c='r', s=2, zorder=3)
527
+ for i, route in enumerate(robots_route):
528
+ robot_marker_color = color_list[i % len(color_list)]
529
+ xPoints = route[0]
530
+ yPoints = route[1]
531
+ ax.plot(xPoints, yPoints, c=robot_marker_color, linewidth=2)
532
+ # ax.plot(xPoints[-1], yPoints[-1], 'mo', markersize=8, zorder=10)
533
+ ax.plot(xPoints[-1], yPoints[-1], markersize=12, zorder=99, marker="^", ls="-", c=robot_marker_color, mec="black")
534
+ ax.plot(xPoints[0], yPoints[0], 'co', c=robot_marker_color, markersize=8, zorder=5)
535
+
536
+ # Sensor range
537
+ rng_x = 0.5 * (self.ground_truth.shape[1] / NUM_COORDS_WIDTH)
538
+ rng_y = 0.5 * (self.ground_truth.shape[0] / NUM_COORDS_HEIGHT)
539
+ max_x = min(xPoints[-1] + int(math.ceil(rng_x)), self.ground_truth.shape[1])
540
+ min_x = max(xPoints[-1] - int(math.ceil(rng_x)), 0)
541
+ max_y = min(yPoints[-1] + int(math.ceil(rng_y)), self.ground_truth.shape[0])
542
+ min_y = max(yPoints[-1] - int(math.ceil(rng_y)), 0)
543
+ ax.plot((min_x, min_x), (min_y, max_y), c=robot_marker_color, linewidth=1)
544
+ ax.plot((min_x, max_x), (max_y, max_y), c=robot_marker_color, linewidth=1)
545
+ ax.plot((max_x, max_x), (max_y, min_y), c=robot_marker_color, linewidth=1)
546
+ ax.plot((max_x, min_x), (min_y, min_y), c=robot_marker_color, linewidth=1)
547
+
548
+ ### Fig2: Graph ###
549
+ ax = ax4 if TAXABIND_TTA else ax1
550
+ # ax.imshow(image)
551
+ ax.imshow(self.coverage_belief, cmap='gray')
552
+ ax.axis((0, self.ground_truth_size[1], self.ground_truth_size[0], 0))
553
+ ax.set_title("Information Graph")
554
+ if VIZ_GRAPH_EDGES:
555
+ for i in range(len(self.graph_generator.x)):
556
+ ax.plot(self.graph_generator.x[i], self.graph_generator.y[i], 'tan', zorder=1)
557
+ # ax.scatter(self.node_coords[:, 0], self.node_coords[:, 1], c=self.node_utility, zorder=5)
558
+ # ax.scatter(self.node_coords[:, 0], self.node_coords[:, 1], c=self.segmentation_info_mask, zorder=5)
559
+ # filtered_seg_info_mask = [info[0] if self.guidepost[i] == 0.0 else 0.0 for i, info in enumerate(self.segmentation_info_mask)]
560
+ ax.scatter(self.node_coords[:, 0], self.node_coords[:, 1], c=self.filtered_seg_info_mask, zorder=5, s=8)
561
+ # ax.scatter(self.frontiers[:, 0], self.frontiers[:, 1], c='r', s=2, zorder=3)
562
+
563
+ for i, route in enumerate(robots_route):
564
+ robot_marker_color = color_list[i % len(color_list)]
565
+ xPoints = route[0]
566
+ yPoints = route[1]
567
+ ax.plot(xPoints, yPoints, c=robot_marker_color, linewidth=2)
568
+ # ax.plot(xPoints[-1], yPoints[-1], 'mo', markersize=8, zorder=10)
569
+ ax.plot(xPoints[-1], yPoints[-1], markersize=12, zorder=99, marker="^", ls="-", c=robot_marker_color, mec="black")
570
+ ax.plot(xPoints[0], yPoints[0], 'co', c=robot_marker_color, markersize=8, zorder=5)
571
+
572
+ # Sensor range
573
+ rng_x = 0.5 * (self.ground_truth.shape[1] / NUM_COORDS_WIDTH)
574
+ rng_y = 0.5 * (self.ground_truth.shape[0] / NUM_COORDS_HEIGHT)
575
+ max_x = min(xPoints[-1] + int(math.ceil(rng_x)), self.ground_truth.shape[1])
576
+ min_x = max(xPoints[-1] - int(math.ceil(rng_x)), 0)
577
+ max_y = min(yPoints[-1] + int(math.ceil(rng_y)), self.ground_truth.shape[0])
578
+ min_y = max(yPoints[-1] - int(math.ceil(rng_y)), 0)
579
+ ax.plot((min_x, min_x), (min_y, max_y), c=robot_marker_color, linewidth=1)
580
+ ax.plot((min_x, max_x), (max_y, max_y), c=robot_marker_color, linewidth=1)
581
+ ax.plot((max_x, max_x), (max_y, min_y), c=robot_marker_color, linewidth=1)
582
+ ax.plot((max_x, min_x), (min_y, min_y), c=robot_marker_color, linewidth=1)
583
+
584
+ # Plot target positions
585
+ for target in self.target_positions:
586
+ if self.coverage_belief[target[1], target[0]] == 255:
587
+ # ax.plot(target[0], target[1], 'go', markersize=8, zorder=99)
588
+ ax.plot(target[0], target[1], color='g', marker='x', linestyle='-', markersize=12, markeredgewidth=4, zorder=99)
589
+ else:
590
+ # ax.plot(target[0], target[1], 'ro', markersize=8, zorder=99)
591
+ ax.plot(target[0], target[1], color='r', marker='x', linestyle='-', markersize=12, markeredgewidth=4, zorder=99)
592
+
593
+ # ax.pause(0.1)
594
+
595
+ ### Fig3: Segmentation Mask ###
596
+ ax = ax5 if TAXABIND_TTA else ax2
597
+ if TAXABIND_TTA and USE_CLIP_PREDS:
598
+ side_dim = int(np.sqrt(self.segmentation_info_mask.shape[0]))
599
+ mask_viz = self.segmentation_info_mask.squeeze().reshape((side_dim, side_dim)).T
600
+ scale_y = math.ceil(self.ground_truth_size[1] / side_dim)
601
+ scale_x = math.ceil(self.ground_truth_size[0] / side_dim)
602
+ upscaled_mask_viz = np.kron(mask_viz, np.ones((scale_y, scale_x))) # Integer scaling only
603
+ upscaled_mask_viz = upscaled_mask_viz[:self.ground_truth_size[1], :self.ground_truth_size[0]]
604
+ im = ax.imshow(upscaled_mask_viz, cmap="viridis")
605
+ ax.axis("off")
606
+ else:
607
+ im = ax.imshow(self.segmentation_mask.mean(axis=-1), cmap='viridis', vmin=0, vmax=100) # cmap='gray'
608
+ ax.axis((0, self.ground_truth_size[1], self.ground_truth_size[0], 0))
609
+ ax.set_title(f"Predicted Mask (Normalized)")
610
+ for i, route in enumerate(robots_route):
611
+ robot_marker_color = color_list[i % len(color_list)]
612
+ xPoints = route[0]
613
+ yPoints = route[1]
614
+ ax.plot(xPoints, yPoints, c=robot_marker_color, linewidth=2)
615
+ # ax.plot(xPoints[-1], yPoints[-1], 'mo', markersize=8, zorder=10)
616
+ ax.plot(xPoints[-1], yPoints[-1], markersize=12, zorder=99, marker="^", ls="-", c=robot_marker_color, mec="black")
617
+ ax.plot(xPoints[0], yPoints[0], 'co', c=robot_marker_color, markersize=8, zorder=5)
618
+
619
+ # Sensor range
620
+ rng_x = 0.5 * (self.ground_truth.shape[1] / NUM_COORDS_WIDTH)
621
+ rng_y = 0.5 * (self.ground_truth.shape[0] / NUM_COORDS_HEIGHT)
622
+ max_x = min(xPoints[-1] + int(math.ceil(rng_x)), self.ground_truth.shape[1])
623
+ min_x = max(xPoints[-1] - int(math.ceil(rng_x)), 0)
624
+ max_y = min(yPoints[-1] + int(math.ceil(rng_y)), self.ground_truth.shape[0])
625
+ min_y = max(yPoints[-1] - int(math.ceil(rng_y)), 0)
626
+ ax.plot((min_x, min_x), (min_y, max_y), c=robot_marker_color, linewidth=1)
627
+ ax.plot((min_x, max_x), (max_y, max_y), c=robot_marker_color, linewidth=1)
628
+ ax.plot((max_x, max_x), (max_y, min_y), c=robot_marker_color, linewidth=1)
629
+ ax.plot((max_x, min_x), (min_y, min_y), c=robot_marker_color, linewidth=1)
630
+
631
+ # Add a colorbar
632
+ cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
633
+ cbar.set_label("Normalized Probs")
634
+
635
+ # ax.pause(0.1)
636
+
637
+ ### Fig4: Segmentation Mask ###
638
+ if TAXABIND_TTA and USE_CLIP_PREDS:
639
+ ax = ax6
640
+ side_dim = int(np.sqrt(self.segmentation_info_mask_unnormalized.shape[0]))
641
+ mask_viz = self.segmentation_info_mask_unnormalized.squeeze().reshape((side_dim, side_dim)).T
642
+ scale_y = math.ceil(self.ground_truth_size[1] / side_dim)
643
+ scale_x = math.ceil(self.ground_truth_size[0] / side_dim)
644
+ upscaled_mask_viz = np.kron(mask_viz, np.ones((scale_y, scale_x))) # Integer scaling only
645
+ upscaled_mask_viz = upscaled_mask_viz[:self.ground_truth_size[1], :self.ground_truth_size[0]]
646
+
647
+ max_val = 0.15 # TO CHANGE
648
+ mid_val = colormap_mid_val if colormap_mid_val is not None else 0.05
649
+ # mid_val = np.max(self.segmentation_info_mask_unnormalized)
650
+ norm = CustomNorm(vmin=0.0, vmax=max_val, mid=mid_val, lower_portion=0.8)
651
+ im = ax.imshow(upscaled_mask_viz, cmap="viridis", norm=norm) # norm=LogNorm(vmin=0.01, vmax=0.1))
652
+ # norm = PowerNorm(gamma=0.25, vmin=0.01, vmax=0.2)
653
+ # norm=LogNorm(vmin=0.01, vmax=0.2)
654
+ im = ax.imshow(upscaled_mask_viz, cmap="viridis", norm=norm) # norm=LogNorm(vmin=0.01, vmax=0.1))
655
+ ax.axis("off")
656
+ # else:
657
+ # im = ax.imshow(self.segmentation_mask.mean(axis=-1), cmap='viridis', vmin=0, vmax=100) # cmap='gray'
658
+ # ax.axis((0, self.ground_truth_size[1], self.ground_truth_size[0], 0))
659
+ ax.set_title(f"Predicted Mask (Unnormalized)")
660
+ for i, route in enumerate(robots_route):
661
+ robot_marker_color = color_list[i % len(color_list)]
662
+ xPoints = route[0]
663
+ yPoints = route[1]
664
+ ax.plot(xPoints, yPoints, c=robot_marker_color, linewidth=2)
665
+ # ax.plot(xPoints[-1], yPoints[-1], 'mo', markersize=8, zorder=10)
666
+ ax.plot(xPoints[-1], yPoints[-1], markersize=12, zorder=99, marker="^", ls="-", c=robot_marker_color, mec="black")
667
+ ax.plot(xPoints[0], yPoints[0], 'co', c=robot_marker_color, markersize=8, zorder=5)
668
+
669
+ # Sensor range
670
+ rng_x = 0.5 * (self.ground_truth.shape[1] / NUM_COORDS_WIDTH)
671
+ rng_y = 0.5 * (self.ground_truth.shape[0] / NUM_COORDS_HEIGHT)
672
+ max_x = min(xPoints[-1] + int(math.ceil(rng_x)), self.ground_truth.shape[1])
673
+ min_x = max(xPoints[-1] - int(math.ceil(rng_x)), 0)
674
+ max_y = min(yPoints[-1] + int(math.ceil(rng_y)), self.ground_truth.shape[0])
675
+ min_y = max(yPoints[-1] - int(math.ceil(rng_y)), 0)
676
+ ax.plot((min_x, min_x), (min_y, max_y), c=robot_marker_color, linewidth=1)
677
+ ax.plot((min_x, max_x), (max_y, max_y), c=robot_marker_color, linewidth=1)
678
+ ax.plot((max_x, max_x), (max_y, min_y), c=robot_marker_color, linewidth=1)
679
+ ax.plot((max_x, min_x), (min_y, min_y), c=robot_marker_color, linewidth=1)
680
+
681
+ # Add a colorbar
682
+ cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
683
+ if TAXABIND_TTA and USE_CLIP_PREDS:
684
+ cbar.set_ticks([0.0, mid_val, max_val])
685
+ cbar.set_label("Probs (Scaled by expectation)")
686
+
687
+
688
+ ### Fig5: Ground Truth Mask ###
689
+ if TARGETS_SET_DIR != "":
690
+ ax = ax2
691
+ im = ax.imshow(self.gt_segmentation_mask.mean(axis=-1), cmap='viridis', vmin=0, vmax=100) # cmap='gray'
692
+ ax.axis((0, self.ground_truth_size[1], self.ground_truth_size[0], 0))
693
+ ax.set_title(f"Ground Truth Mask")
694
+ for i, route in enumerate(robots_route):
695
+ robot_marker_color = color_list[i % len(color_list)]
696
+ xPoints = route[0]
697
+ yPoints = route[1]
698
+ ax.plot(xPoints, yPoints, c=robot_marker_color, linewidth=2)
699
+ # ax.plot(xPoints[-1], yPoints[-1], 'mo', markersize=8, zorder=10)
700
+ ax.plot(xPoints[-1], yPoints[-1], markersize=12, zorder=99, marker="^", ls="-", c=robot_marker_color, mec="black")
701
+ ax.plot(xPoints[0], yPoints[0], 'co', c=robot_marker_color, markersize=8, zorder=5)
702
+
703
+ # Sensor range
704
+ rng_x = 0.5 * (self.ground_truth.shape[1] / NUM_COORDS_WIDTH)
705
+ rng_y = 0.5 * (self.ground_truth.shape[0] / NUM_COORDS_HEIGHT)
706
+ max_x = min(xPoints[-1] + int(math.ceil(rng_x)), self.ground_truth.shape[1])
707
+ min_x = max(xPoints[-1] - int(math.ceil(rng_x)), 0)
708
+ max_y = min(yPoints[-1] + int(math.ceil(rng_y)), self.ground_truth.shape[0])
709
+ min_y = max(yPoints[-1] - int(math.ceil(rng_y)), 0)
710
+ ax.plot((min_x, min_x), (min_y, max_y), c=robot_marker_color, linewidth=1)
711
+ ax.plot((min_x, max_x), (max_y, max_y), c=robot_marker_color, linewidth=1)
712
+ ax.plot((max_x, max_x), (max_y, min_y), c=robot_marker_color, linewidth=1)
713
+ ax.plot((max_x, min_x), (min_y, min_y), c=robot_marker_color, linewidth=1)
714
+
715
+ # Add a colorbar
716
+ cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
717
+ cbar.set_label("Normalized Mask Value")
718
+
719
+ # ax4.pause(0.1)
720
+
721
+
722
+ ### Fig6: Segmentation Mask (GT) ###
723
+ if TAXABIND_TTA:
724
+ ax = ax3
725
+ image = mpimg.imread(img_path_override)
726
+ ax.imshow(image)
727
+ ax.set_title("Ground Image")
728
+ ax.axis("off")
729
+
730
+
731
+ sound_id = sound_id_override if sound_id_override is not None else "-1"
732
+ plt.suptitle('Targets Found: {}/{} Coverage ratio: {:.4g} Travel Dist: {:.4g} Info Gain: {:.4g}% \n ({}) \n (Sound ID: {})'.format(self.num_targets_found, \
733
+ len(self.target_positions), self.explored_rate, travel_dist, (100*self.info_gain/self.total_info), msk_name,
734
+ sound_id))
735
+ plt.tight_layout()
736
+ plt.savefig('{}/{}_{}_samples.png'.format(path, n, step), dpi=100)
737
+ # plt.show()
738
+ frame = '{}/{}_{}_samples.png'.format(path, n, step)
739
+ self.frame_files.append(frame)
740
+ plt.close()
741
+
742
+
743
+ ####################
744
+
745
+ class CustomNorm(Normalize):
746
+ """
747
+ A custom normalization that allocates a larger fraction of the colormap
748
+ to the lower data range [vmin, mid] than to [mid, vmax].
749
+
750
+ Parameters
751
+ ----------
752
+ vmin : float
753
+ Minimum data value
754
+ vmax : float
755
+ Maximum data value
756
+ mid : float
757
+ Midpoint in data where we switch from 'lower' to 'upper' mapping
758
+ lower_portion : float
759
+ Fraction of the colormap to allocate for [vmin, mid].
760
+ For example, 0.8 => 80% of colors for [vmin, mid], 20% for [mid, vmax].
761
+ clip : bool
762
+ Whether to clip data outside [vmin, vmax].
763
+ """
764
+
765
+ def __init__(self, vmin=None, vmax=None, mid=0.05, lower_portion=0.8, clip=False):
766
+ self.mid = mid
767
+ self.lower_portion = lower_portion
768
+ super().__init__(vmin, vmax, clip)
769
+
770
+ def __call__(self, value, clip=None):
771
+ """Forward transform: data -> [0..1] color space."""
772
+ vmin, vmax, mid = self.vmin, self.vmax, self.mid
773
+ lp = self.lower_portion
774
+
775
+ value = np.asarray(value, dtype=np.float64)
776
+
777
+ # Piecewise linear mapping:
778
+ # [vmin..mid] => [0..lp]
779
+ # [mid..vmax] => [lp..1]
780
+ normed = np.where(
781
+ value <= mid,
782
+ lp * (value - vmin) / (mid - vmin),
783
+ lp + (value - mid) / (vmax - mid) * (1 - lp)
784
+ )
785
+ return np.clip(normed, 0, 1)
786
+
787
+ def inverse(self, value):
788
+ """
789
+ Inverse transform: [0..1] color space -> data space.
790
+ Matplotlib's colorbar calls this to place ticks correctly.
791
+ """
792
+ vmin, vmax, mid = self.vmin, self.vmax, self.mid
793
+ lp = self.lower_portion
794
+
795
+ value = np.asarray(value, dtype=np.float64)
796
+
797
+ # For color space [0..lp], invert to [vmin..mid]
798
+ # For color space [lp..1], invert to [mid..vmax]
799
+ below = (value <= lp)
800
+ above = ~below
801
+
802
+ # Allocate array for results
803
+ data = np.zeros_like(value, dtype=np.float64)
804
+
805
+ # Invert lower segment
806
+ data[below] = vmin + (value[below] / lp) * (mid - vmin)
807
+
808
+ # Invert upper segment
809
+ data[above] = mid + ((value[above] - lp) / (1 - lp)) * (vmax - mid)
810
+
811
+ return data
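A minimal usage sketch for CustomNorm above (not part of the commit; the data, vmin/vmax/mid values, and colormap below are placeholders chosen for illustration):

import numpy as np
import matplotlib.pyplot as plt

data = np.random.rand(24, 24) * 0.2                      # mostly small probabilities
norm = CustomNorm(vmin=0.0, vmax=0.2, mid=0.05, lower_portion=0.8)

fig, ax = plt.subplots()
im = ax.imshow(data, cmap='viridis', norm=norm)          # 80% of the colormap covers [0, 0.05]
cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
cbar.set_ticks([0.0, 0.05, 0.2])                         # inverse() places these ticks correctly
plt.show()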
graph.py ADDED
@@ -0,0 +1,179 @@
1
+ #######################################################################
2
+ # Name: graph.py
3
+ #
4
+ # - Adapted from https://gist.github.com/betandr/541a1f6466b6855471de5ca30b74cb31
5
+ # - Simple graph class to perform distance calculations (e.g. A-Star, Dijkstra)
6
+ #######################################################################
7
+
8
+ from decimal import Decimal
9
+
10
+
11
+ class Edge:
12
+ def __init__(self, to_node, length):
13
+ self.to_node = to_node
14
+ self.length = length
15
+
16
+
17
+ class Graph:
18
+ def __init__(self):
19
+ self.nodes = set()
20
+ self.edges = dict()
21
+
22
+ def add_node(self, node):
23
+ self.nodes.add(node)
24
+
25
+ def add_edge(self, from_node, to_node, length):
26
+ edge = Edge(to_node, length)
27
+ # edge = to_node
28
+ if from_node in self.edges:
29
+ from_node_edges = self.edges[from_node]
30
+ else:
31
+ self.edges[from_node] = dict()
32
+ from_node_edges = self.edges[from_node]
33
+ # if edge not in from_node_edges:
34
+ # from_node_edges.append(edge)
35
+ from_node_edges[to_node] = edge
36
+
37
+ def clear_edge(self, from_node):
38
+ if from_node in self.edges:
39
+ self.edges[from_node] = dict()
40
+ # else:
41
+ # print("no edge node or no this node")
42
+
43
+ def min_dist(q, dist):
44
+ """
45
+ Returns the node with the smallest distance in q.
46
+ Implemented to keep the main algorithm clean.
47
+ """
48
+ min_node = None
49
+ for node in q:
50
+ if min_node is None:
51
+ min_node = node
52
+ elif dist[node] < dist[min_node]:
53
+ min_node = node
54
+
55
+ return min_node
56
+
57
+
58
+ INFINITY = float('Infinity')
59
+
60
+
61
+ def dijkstra(graph, source):
62
+ q = set()
63
+ dist = {}
64
+ prev = {}
65
+
66
+ for v in graph.nodes: # initialization
67
+ dist[v] = INFINITY # unknown distance from source to v
68
+ prev[v] = INFINITY # previous node in optimal path from source
69
+ q.add(v) # all nodes initially in q (unvisited nodes)
70
+
71
+ # distance from source to source
72
+ dist[source] = 0
73
+
74
+ while q:
75
+ # node with the least distance selected first
76
+ u = min_dist(q, dist)
77
+
78
+ q.remove(u)
79
+
80
+ try:
81
+ if u in graph.edges:
82
+ for _, v in graph.edges[u].items():
83
+ alt = dist[u] + v.length
84
+ if alt < dist[v.to_node]:
85
+ # a shorter path to v has been found
86
+ dist[v.to_node] = alt
87
+ prev[v.to_node] = u
88
+ except:
89
+ pass
90
+
91
+ return dist, prev
92
+
93
+
94
+ def to_array(prev, from_node):
95
+ """Creates an ordered list of labels as a route."""
96
+ previous_node = prev[from_node]
97
+ route = [from_node]
98
+
99
+ while previous_node != INFINITY:
100
+ route.append(previous_node)
101
+ temp = previous_node
102
+ previous_node = prev[temp]
103
+
104
+ route.reverse()
105
+ return route
106
+
107
+
108
+ def h(index, destination, node_coords):
109
+ current = node_coords[index]
110
+ end = node_coords[destination]
111
+ h = abs(end[0] - current[0]) + abs(end[1] - current[1])
112
+ # h = ((end[0]-current[0])**2 + (end[1] - current[1])**2)**(1/2)
113
+ return h
114
+
115
+
116
+ def a_star(start, destination, node_coords, graph):
117
+ if start == destination:
118
+ return [], 0
119
+ if str(destination) in graph.edges[str(start)].keys():
120
+ cost = graph.edges[str(start)][str(destination)].length
121
+ return [start, destination], cost
122
+ open_list = {start}
123
+ closed_list = set([])
124
+
125
+ g = {start: 0}
126
+ parents = {start: start}
127
+
128
+ while len(open_list) > 0:
129
+ n = None
130
+ h_n = 1e5
131
+ # print('open list', open_list)
132
+ for v in open_list:
133
+ h_v = h(v, destination, node_coords)
134
+ if n is not None:
135
+ h_n = h(n, destination, node_coords)
136
+ if n is None or g[v] + h_v < g[n] + h_n:
137
+ n = v
138
+
139
+ if n is None:
140
+ print('Path does not exist!')
141
+ return None, 1e5
142
+
143
+ if n == destination:
144
+ reconst_path = []
145
+ while parents[n] != n:
146
+ reconst_path.append(n)
147
+ n = parents[n]
148
+ reconst_path.append(start)
149
+ reconst_path.reverse()
150
+ # print('Path found: {}'.format(reconst_path))
151
+ # print(g[destination])
152
+ return reconst_path, g[destination]
153
+
154
+ for edge in graph.edges[str(n)].values():
155
+ m = int(edge.to_node)
156
+ cost = edge.length
157
+ # print(m, cost)
158
+ if m not in open_list and m not in closed_list:
159
+ open_list.add(m)
160
+ parents[m] = n
161
+ g[m] = g[n] + cost
162
+
163
+ else:
164
+ if g[m] > g[n] + cost:
165
+ g[m] = g[n] + cost
166
+ parents[m] = n
167
+
168
+ if m in closed_list:
169
+ closed_list.remove(m)
170
+ open_list.add(m)
171
+
172
+ open_list.remove(n)
173
+ closed_list.add(n)
174
+
175
+ print('Path does not exist!')
176
+ return None, 1e5
177
+
178
+
179
+
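Hypothetical usage sketch for the Graph / dijkstra / a_star utilities above (node ids, coordinates, and edge lengths are made up for illustration; in this repo graph_generator.py builds these structures from the robot belief):

import numpy as np

node_coords = np.array([[0, 0], [1, 0], [1, 1], [2, 1]])
g = Graph()
for i in range(len(node_coords)):
    g.add_node(str(i))
for a, b in [(0, 1), (1, 2), (2, 3), (0, 2)]:
    d = float(np.linalg.norm(node_coords[a] - node_coords[b]))
    g.add_edge(str(a), str(b), d)        # edges are keyed by string node ids
    g.add_edge(str(b), str(a), d)

dist, prev = dijkstra(g, '0')            # shortest distances from node '0'
route, cost = a_star(0, 3, node_coords, g)
print(route, cost)                       # expected: [0, 2, 3] with cost ~2.41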
graph_generator.py ADDED
@@ -0,0 +1,330 @@
1
+ #######################################################################
2
+ # Name: graph_generator.py
3
+ #
4
+ # - Wrapper for graph.py
5
+ # - Sends the formatted inputs into graph.py to get useful info
6
+ #######################################################################
7
+
8
+ import sys
9
+ if sys.modules['TRAINING']:
10
+ from parameter import *
11
+ else:
12
+ from test_parameter import *
13
+
14
+ import numpy as np
15
+ from sklearn.neighbors import NearestNeighbors
16
+ import shapely.geometry
17
+
18
+ from node import Node
19
+ from graph import Graph, a_star
20
+
21
+
22
+ class Graph_generator:
23
+ def __init__(self, map_size, k_size, sensor_range, plot=False):
24
+ self.k_size = k_size
25
+ self.graph = Graph()
26
+ self.node_coords = None
27
+ self.plot = plot
28
+ self.x = []
29
+ self.y = []
30
+ self.map_x = map_size[1]
31
+ self.map_y = map_size[0]
32
+ self.uniform_points, self.grid_coords = self.generate_uniform_points()
33
+
34
+ self.sensor_range = sensor_range
35
+
36
+ self.route_node = []
37
+ self.nodes_list = []
38
+ self.node_utility = None
39
+ self.guidepost = None
40
+
41
+ def edge_clear_all_nodes(self):
42
+ self.graph = Graph()
43
+ self.x = []
44
+ self.y = []
45
+
46
+ def edge_clear(self, coords):
47
+ node_index = str(self.find_index_from_coords(self.node_coords, coords))
48
+ self.graph.clear_edge(node_index)
49
+
50
+ def generate_graph(self, robot_belief, frontiers):
51
+ self.edge_clear_all_nodes()
52
+ free_area = self.free_area(robot_belief)
53
+
54
+ free_area_to_check = free_area[:, 0] + free_area[:, 1] * 1j
55
+ uniform_points_to_check = self.uniform_points[:, 0] + self.uniform_points[:, 1] * 1j
56
+ _, _, candidate_indices = np.intersect1d(free_area_to_check, uniform_points_to_check, return_indices=True)
57
+ node_coords = self.uniform_points[candidate_indices]
58
+
59
+ # node_coords = np.concatenate((robot_location.reshape(1, 2), node_coords))
60
+ self.node_coords = self.unique_coords(node_coords).reshape(-1, 2)
61
+ # self.find_k_neighbor_all_nodes(self.node_coords, robot_belief)
62
+ self.find_nearest_neighbor_all_nodes(self.node_coords, robot_belief)
63
+
64
+ self.node_utility = []
65
+ for coords in self.node_coords:
66
+ node = Node(coords, frontiers, robot_belief)
67
+ self.nodes_list.append(node)
68
+ utility = node.utility
69
+ self.node_utility.append(utility)
70
+ self.node_utility = np.array(self.node_utility)
71
+
72
+ self.guidepost = np.zeros((self.node_coords.shape[0], 1))
73
+ x = self.node_coords[:,0] + self.node_coords[:,1]*1j
74
+ for node in self.route_node:
75
+ index = np.argwhere(x.reshape(-1) == node[0]+node[1]*1j)[0]
76
+ self.guidepost[index] = 1
77
+
78
+ return self.node_coords, self.graph.edges, self.node_utility, self.guidepost
79
+
80
+ def update_graph(self, robot_belief, old_robot_belief, frontiers, old_frontiers):
81
+ new_free_area = self.free_area((robot_belief - old_robot_belief > 0) * 255)
82
+ free_area_to_check = new_free_area[:, 0] + new_free_area[:, 1] * 1j
83
+ uniform_points_to_check = self.uniform_points[:, 0] + self.uniform_points[:, 1] * 1j
84
+ _, _, candidate_indices = np.intersect1d(free_area_to_check, uniform_points_to_check, return_indices=True)
85
+ new_node_coords = self.uniform_points[candidate_indices]
86
+ self.node_coords = np.concatenate((self.node_coords, new_node_coords))
87
+
88
+ old_node_to_update = []
89
+ for coords in new_node_coords:
90
+ neighbor_indices = self.find_k_neighbor(coords, self.node_coords, robot_belief)
91
+ old_node_to_update += neighbor_indices
92
+ old_node_to_update = set(old_node_to_update)
93
+ for index in old_node_to_update:
94
+ coords = self.node_coords[index]
95
+ self.edge_clear(coords)
96
+ self.find_k_neighbor(coords, self.node_coords, robot_belief)
97
+
98
+ #self.edge_clear_all_nodes()
99
+ #self.find_k_neighbor_all_nodes(self.node_coords, robot_belief)
100
+
101
+ old_frontiers_to_check = old_frontiers[:, 0] + old_frontiers[:, 1] * 1j
102
+ new_frontiers_to_check = frontiers[:, 0] + frontiers[:, 1] * 1j
103
+ observed_frontiers_index = np.where(
104
+ np.isin(old_frontiers_to_check, new_frontiers_to_check, assume_unique=True) == False)
105
+ new_frontiers_index = np.where(
106
+ np.isin(new_frontiers_to_check, old_frontiers_to_check, assume_unique=True) == False)
107
+ observed_frontiers = old_frontiers[observed_frontiers_index]
108
+ new_frontiers = frontiers[new_frontiers_index]
109
+ for node in self.nodes_list:
110
+ if node.zero_utility_node is True:
111
+ pass
112
+ else:
113
+ node.update_observable_frontiers(observed_frontiers, new_frontiers, robot_belief)
114
+
115
+ for new_coords in new_node_coords:
116
+ node = Node(new_coords, frontiers, robot_belief)
117
+ self.nodes_list.append(node)
118
+
119
+ self.node_utility = []
120
+ for i, coords in enumerate(self.node_coords):
121
+ utility = self.nodes_list[i].utility
122
+ self.node_utility.append(utility)
123
+ self.node_utility = np.array(self.node_utility)
124
+
125
+ self.guidepost = np.zeros((self.node_coords.shape[0], 1))
126
+ x = self.node_coords[:, 0] + self.node_coords[:, 1] * 1j
127
+ for node in self.route_node:
128
+ index = np.argwhere(x.reshape(-1) == node[0] + node[1] * 1j)
129
+ self.guidepost[index] = 1
130
+
131
+ return self.node_coords, self.graph.edges, self.node_utility, self.guidepost
132
+
133
+ def generate_uniform_points(self):
134
+ # x = np.linspace(0, self.map_x - 1, NUM_COORDS_WIDTH).round().astype(int)
135
+ # y = np.linspace(0, self.map_y - 1, NUM_COORDS_HEIGHT).round().astype(int)
136
+ padding_x = 0.5 * (self.map_x / NUM_COORDS_WIDTH)
137
+ padding_y = 0.5 * (self.map_y / NUM_COORDS_HEIGHT)
138
+ x = np.linspace(padding_x, self.map_x - padding_x - 1, NUM_COORDS_WIDTH).round().astype(int)
139
+ y = np.linspace(padding_y, self.map_y - padding_y - 1, NUM_COORDS_HEIGHT).round().astype(int)
140
+
141
+ t1, t2 = np.meshgrid(x, y)
142
+ points = np.vstack([t1.T.ravel(), t2.T.ravel()]).T
143
+ matrix = np.stack((t1, t2), axis=-1)
144
+ return points, matrix
145
+
146
+ def free_area(self, robot_belief):
147
+ index = np.where(robot_belief == 255)
148
+ free = np.asarray([index[1], index[0]]).T
149
+ return free
150
+
151
+ def unique_coords(self, coords):
152
+ x = coords[:, 0] + coords[:, 1] * 1j
153
+ indices = np.unique(x, return_index=True)[1]
154
+ coords = np.array([coords[idx] for idx in sorted(indices)])
155
+ return coords
156
+
157
+ def find_k_neighbor(self, coords, node_coords, robot_belief):
158
+ dist_list = np.linalg.norm((coords-node_coords), axis=-1)
159
+ sorted_index = np.argsort(dist_list)
160
+ k = 0
161
+ neighbor_index_list = []
162
+ while k < self.k_size and k < node_coords.shape[0]:
163
+ neighbor_index = sorted_index[k]
164
+ neighbor_index_list.append(neighbor_index)
165
+ dist = dist_list[neighbor_index]  # dist_list is indexed by node id, not by rank
166
+ start = coords
167
+ end = node_coords[neighbor_index]
168
+ if not self.check_collision(start, end, robot_belief):
169
+ a = str(self.find_index_from_coords(node_coords, start))
170
+ b = str(neighbor_index)
171
+ self.graph.add_node(a)
172
+ self.graph.add_edge(a, b, dist)
173
+
174
+ if self.plot:
175
+ self.x.append([start[0], end[0]])
176
+ self.y.append([start[1], end[1]])
177
+ k += 1
178
+ return neighbor_index_list
179
+
180
+ def find_k_neighbor_all_nodes(self, node_coords, robot_belief):
181
+ X = node_coords
182
+ if len(node_coords) >= self.k_size:
183
+ knn = NearestNeighbors(n_neighbors=self.k_size)
184
+ else:
185
+ knn = NearestNeighbors(n_neighbors=len(node_coords))
186
+ knn.fit(X)
187
+ distances, indices = knn.kneighbors(X)
188
+
189
+ for i, p in enumerate(X):
190
+ for j, neighbour in enumerate(X[indices[i][:]]):
191
+ start = p
192
+ end = neighbour
193
+ if not self.check_collision(start, end, robot_belief):
194
+ a = str(self.find_index_from_coords(node_coords, p))
195
+ b = str(self.find_index_from_coords(node_coords, neighbour))
196
+ self.graph.add_node(a)
197
+ self.graph.add_edge(a, b, distances[i, j])
198
+
199
+ if self.plot:
200
+ self.x.append([p[0], neighbour[0]])
201
+ self.y.append([p[1], neighbour[1]])
202
+
203
+ def find_nearest_neighbor_all_nodes(self, node_coords, robot_belief):
204
+ for i, p in enumerate(node_coords):
205
+ filtered_coords = self.get_neighbors_grid_coords(p)
206
+
207
+ for j, neighbour in enumerate(filtered_coords):
208
+ start = p
209
+ end = neighbour
210
+ if not self.check_collision(start, end, robot_belief):
211
+ a = str(self.find_index_from_coords(node_coords, p))
212
+ b = str(self.find_index_from_coords(node_coords, neighbour))
213
+ self.graph.add_node(a)
214
+ self.graph.add_edge(a, b, np.linalg.norm(start-end))
215
+
216
+ if self.plot:
217
+ self.x.append([p[0], neighbour[0]])
218
+ self.y.append([p[1], neighbour[1]])
219
+
220
+
221
+
222
+ def find_index_from_coords(self, node_coords, p):
223
+ return np.where(np.linalg.norm(node_coords - p, axis=1) < 1e-5)[0][0]
224
+
225
+ def find_closest_index_from_coords(self, node_coords, p):
226
+ return np.argmin(np.linalg.norm(node_coords - p, axis=1))
227
+
228
+ def find_index_from_grid_coords_2d(self, p):
229
+
230
+ # Calculate the distance between the target coord and each point in grid_coords
231
+ diffs = np.linalg.norm(self.grid_coords - p, axis=2) # Compute along the last axis (x, y)
232
+ indices = np.where(diffs < 1e-5)
233
+
234
+ # Return the 2D index as a tuple (row, col)
235
+ if indices[0].size > 0:
236
+ return indices[0][0], indices[1][0]
237
+ else:
238
+ raise ValueError(f"Coordinate {p} not found in self.grid_coords.")
239
+
240
+ def find_closest_index_from_grid_coords_2d(self, p):
241
+
242
+ # Calculate the distance between the target coord and each point in grid_coords
243
+ distances = np.linalg.norm(self.grid_coords - p, axis=2) # Compute along the last axis (x, y)
244
+ flat_index = np.argmin(distances)
245
+ return np.unravel_index(flat_index, distances.shape)
246
+
247
+
248
+ def check_collision(self, start, end, robot_belief):
249
+ collision = False
250
+ line = shapely.geometry.LineString([start, end])
251
+
252
+ sortx = np.sort([start[0], end[0]])
253
+ sorty = np.sort([start[1], end[1]])
254
+
255
+ # print(robot_belief.shape)
256
+ robot_belief = robot_belief[sorty[0]:sorty[1]+1, sortx[0]:sortx[1]+1]
257
+
258
+ occupied_area_index = np.where(robot_belief == 1)
259
+ occupied_area_coords = np.asarray([occupied_area_index[1]+sortx[0], occupied_area_index[0]+sorty[0]]).T
260
+ unexplored_area_index = np.where(robot_belief == 127)
261
+ unexplored_area_coords = np.asarray([unexplored_area_index[1]+sortx[0], unexplored_area_index[0]+sorty[0]]).T
262
+ unfree_area_coords = np.concatenate((occupied_area_coords, unexplored_area_coords))
263
+
264
+ # obstacles = []
265
+ for i in range(unfree_area_coords.shape[0]):
266
+ coords = ([(unfree_area_coords[i][0], unfree_area_coords[i][1]),
267
+ (unfree_area_coords[i][0] + 1, unfree_area_coords[i][1]),
268
+ (unfree_area_coords[i][0], unfree_area_coords[i][1] + 1),
269
+ (unfree_area_coords[i][0] + 1, unfree_area_coords[i][1] + 1)])
270
+ obstacle = shapely.geometry.Polygon(coords)
271
+ #obstacles.append(obstacle)
272
+ # if obstacles != []:
273
+ #all_obstacles = shapely.geometry.MultiPolygon(obstacles)
274
+ # print(obstacle.is_valid)
275
+ collision = line.intersects(obstacle)
276
+ if collision:
277
+ break
278
+
279
+ return collision
280
+
281
+ def find_shortest_path(self, current, destination, node_coords):
282
+ start_node = str(self.find_index_from_coords(node_coords, current))
283
+ end_node = str(self.find_index_from_coords(node_coords, destination))
284
+ route, dist = a_star(int(start_node), int(end_node), self.node_coords, self.graph)
285
+ if start_node != end_node:
286
+ assert route != []
287
+ # t3 = time.time()
288
+ # print(t2-t1, t3-t2)
289
+ route = list(map(str, route))
290
+ return dist, route
291
+
292
+
293
+ # def get_neighbors_grid_coords(self, coord):
294
+ # # Return the 4 closest neighbors (N, S, E, W) of a given coordinate
295
+
296
+ # nearest_coord = self.node_coords[self.find_closest_index_from_coords(self.node_coords, coord)]
297
+ # rows, cols = self.grid_coords.shape[:2]
298
+ # neighbors = []
299
+ # i, j = self.find_index_from_grid_coords_2d(nearest_coord)
300
+
301
+ # # Define NSEW neighbor offsets
302
+ # directions = [(-1, 0), (1, 0), (0, -1), (0, 1)] # N, S, W, E
303
+
304
+ # for di, dj in directions:
305
+ # ni, nj = i + di, j + dj
306
+ # if 0 <= ni < rows and 0 <= nj < cols:
307
+ # neighbors.append(tuple(self.grid_coords[ni, nj]))
308
+
309
+ # return neighbors
310
+
311
+
312
+ def get_neighbors_grid_coords(self, coord):
313
+ # Return the 8 closest neighbors of a given coordinate
314
+
315
+ nearest_coord = self.node_coords[self.find_closest_index_from_coords(self.node_coords, coord)]
316
+ rows, cols = self.grid_coords.shape[:2]
317
+ neighbors = []
318
+ i, j = self.find_index_from_grid_coords_2d(nearest_coord)
319
+
320
+ # Create a range of indices for rows and columns
321
+ row_range = np.clip([i - 1, i, i + 1], 0, rows - 1)
322
+ col_range = np.clip([j - 1, j, j + 1], 0, cols - 1)
323
+
324
+ # Iterate over the valid indices
325
+ for ni in row_range:
326
+ for nj in col_range:
327
+ if (ni, nj) != (i, j): # Skip the center point
328
+ neighbors.append(tuple(self.grid_coords[ni, nj]))
329
+
330
+ return neighbors
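Standalone sketch of two building blocks used above: the padded uniform candidate grid and the complex-number trick for matching grid points against free cells (the map size and the 24x24 grid resolution are assumptions standing in for the values in parameter.py):

import numpy as np

map_h, map_w = 240, 240
num_w, num_h = 24, 24
pad_x, pad_y = 0.5 * map_w / num_w, 0.5 * map_h / num_h
x = np.linspace(pad_x, map_w - pad_x - 1, num_w).round().astype(int)
y = np.linspace(pad_y, map_h - pad_y - 1, num_h).round().astype(int)
t1, t2 = np.meshgrid(x, y)
uniform_points = np.vstack([t1.T.ravel(), t2.T.ravel()]).T

robot_belief = np.full((map_h, map_w), 127)      # 127 = unexplored
robot_belief[:120, :] = 255                      # 255 = known free space
free_idx = np.where(robot_belief == 255)
free_area = np.asarray([free_idx[1], free_idx[0]]).T

free_c = free_area[:, 0] + free_area[:, 1] * 1j
points_c = uniform_points[:, 0] + uniform_points[:, 1] * 1j
_, _, candidate_indices = np.intersect1d(free_c, points_c, return_indices=True)
node_coords = uniform_points[candidate_indices]  # grid nodes that lie in free space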
inference/model/STAGE1_vlm_search_24x24_040425_no_tgt_rewards_iNAT_DS_16k.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:44e642df9aaa2847ba44dd4707985c67ef712f5264272ef7993aeb7805c80f5a
3
+ size 52167246
model.py ADDED
@@ -0,0 +1,319 @@
1
+ #######################################################################
2
+ # Name: model.py
3
+ #
4
+ # - Attention-based encoders & decoders
5
+ # - Policy Net: Input = Augmented Graph, Output = Node to go to
6
+ # - Critic Net: Input = Augmented Graph + Action, Output = Q_Value
7
+ #######################################################################
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import math
12
+
13
+
14
+ class SingleHeadAttention(nn.Module):
15
+ def __init__(self, embedding_dim):
16
+ super(SingleHeadAttention, self).__init__()
17
+ self.input_dim = embedding_dim
18
+ self.embedding_dim = embedding_dim
19
+ self.value_dim = embedding_dim
20
+ self.key_dim = self.value_dim
21
+ self.tanh_clipping = 10
22
+ self.norm_factor = 1 / math.sqrt(self.key_dim)
23
+
24
+ self.w_query = nn.Parameter(torch.Tensor(self.input_dim, self.key_dim))
25
+ self.w_key = nn.Parameter(torch.Tensor(self.input_dim, self.key_dim))
26
+
27
+ self.init_parameters()
28
+
29
+ def init_parameters(self):
30
+ for param in self.parameters():
31
+ stdv = 1. / math.sqrt(param.size(-1))
32
+ param.data.uniform_(-stdv, stdv)
33
+
34
+ def forward(self, q, k, mask=None):
35
+
36
+ n_batch, n_key, n_dim = k.size()
37
+ n_query = q.size(1)
38
+
39
+ k_flat = k.reshape(-1, n_dim)
40
+ q_flat = q.reshape(-1, n_dim)
41
+
42
+ shape_k = (n_batch, n_key, -1)
43
+ shape_q = (n_batch, n_query, -1)
44
+
45
+ Q = torch.matmul(q_flat, self.w_query).view(shape_q)
46
+ K = torch.matmul(k_flat, self.w_key).view(shape_k)
47
+
48
+ U = self.norm_factor * torch.matmul(Q, K.transpose(1, 2))
49
+ U = self.tanh_clipping * torch.tanh(U)
50
+
51
+ if mask is not None:
52
+ U = U.masked_fill(mask == 1, -1e8)
53
+ attention = torch.log_softmax(U, dim=-1) # n_batch*n_query*n_key
54
+
55
+ return attention
56
+
57
+
58
+ class MultiHeadAttention(nn.Module):
59
+ def __init__(self, embedding_dim, n_heads=8):
60
+ super(MultiHeadAttention, self).__init__()
61
+ self.n_heads = n_heads
62
+ self.input_dim = embedding_dim
63
+ self.embedding_dim = embedding_dim
64
+ self.value_dim = self.embedding_dim // self.n_heads
65
+ self.key_dim = self.value_dim
66
+ self.norm_factor = 1 / math.sqrt(self.key_dim)
67
+
68
+ self.w_query = nn.Parameter(torch.Tensor(self.n_heads, self.input_dim, self.key_dim))
69
+ self.w_key = nn.Parameter(torch.Tensor(self.n_heads, self.input_dim, self.key_dim))
70
+ self.w_value = nn.Parameter(torch.Tensor(self.n_heads, self.input_dim, self.value_dim))
71
+ self.w_out = nn.Parameter(torch.Tensor(self.n_heads, self.value_dim, self.embedding_dim))
72
+
73
+ self.init_parameters()
74
+
75
+ def init_parameters(self):
76
+ for param in self.parameters():
77
+ stdv = 1. / math.sqrt(param.size(-1))
78
+ param.data.uniform_(-stdv, stdv)
79
+
80
+ def forward(self, q, k=None, v=None, key_padding_mask=None, attn_mask=None):
81
+ if k is None:
82
+ k = q
83
+ if v is None:
84
+ v = q
85
+
86
+ n_batch, n_key, n_dim = k.size()
87
+ n_query = q.size(1)
88
+ n_value = v.size(1)
89
+
90
+ k_flat = k.contiguous().view(-1, n_dim)
91
+ v_flat = v.contiguous().view(-1, n_dim)
92
+ q_flat = q.contiguous().view(-1, n_dim)
93
+ shape_v = (self.n_heads, n_batch, n_value, -1)
94
+ shape_k = (self.n_heads, n_batch, n_key, -1)
95
+ shape_q = (self.n_heads, n_batch, n_query, -1)
96
+
97
+ Q = torch.matmul(q_flat, self.w_query).view(shape_q) # n_heads*batch_size*n_query*key_dim
98
+ K = torch.matmul(k_flat, self.w_key).view(shape_k) # n_heads*batch_size*targets_size*key_dim
99
+ V = torch.matmul(v_flat, self.w_value).view(shape_v) # n_heads*batch_size*targets_size*value_dim
100
+
101
+ U = self.norm_factor * torch.matmul(Q, K.transpose(2, 3)) # n_heads*batch_size*n_query*targets_size
102
+
103
+ if attn_mask is not None:
104
+ attn_mask = attn_mask.view(1, n_batch, n_query, n_key).expand_as(U)
105
+
106
+ if key_padding_mask is not None:
107
+ key_padding_mask = key_padding_mask.repeat(1, n_query, 1)
108
+ key_padding_mask = key_padding_mask.view(1, n_batch, n_query, n_key).expand_as(U) # copy for n_heads times
109
+
110
+ if attn_mask is not None and key_padding_mask is not None:
111
+ mask = (attn_mask + key_padding_mask)
112
+ elif attn_mask is not None:
113
+ mask = attn_mask
114
+ elif key_padding_mask is not None:
115
+ mask = key_padding_mask
116
+ else:
117
+ mask = None
118
+
119
+ if mask is not None:
120
+ U = U.masked_fill(mask > 0, -1e8)
121
+
122
+ attention = torch.softmax(U, dim=-1) # n_heads*batch_size*n_query*targets_size
123
+
124
+ heads = torch.matmul(attention, V) # n_heads*batch_size*n_query*value_dim
125
+
126
+ # out = heads.permute(1, 2, 0, 3).reshape(n_batch, n_query, n_dim)
127
+ out = torch.mm(
128
+ heads.permute(1, 2, 0, 3).reshape(-1, self.n_heads * self.value_dim),
129
+ # batch_size*n_query*n_heads*value_dim
130
+ self.w_out.view(-1, self.embedding_dim)
131
+ # n_heads*value_dim*embedding_dim
132
+ ).view(-1, n_query, self.embedding_dim)
133
+
134
+
135
+ return out, attention # batch_size*n_query*embedding_dim
136
+
137
+
138
+ class Normalization(nn.Module):
139
+ def __init__(self, embedding_dim):
140
+ super(Normalization, self).__init__()
141
+ self.normalizer = nn.LayerNorm(embedding_dim)
142
+
143
+ def forward(self, input):
144
+ return self.normalizer(input.view(-1, input.size(-1))).view(*input.size())
145
+
146
+
147
+ class EncoderLayer(nn.Module):
148
+ def __init__(self, embedding_dim, n_head):
149
+ super(EncoderLayer, self).__init__()
150
+ self.multiHeadAttention = MultiHeadAttention(embedding_dim, n_head)
151
+ self.normalization1 = Normalization(embedding_dim)
152
+ self.feedForward = nn.Sequential(nn.Linear(embedding_dim, 512), nn.ReLU(inplace=True),
153
+ nn.Linear(512, embedding_dim))
154
+ self.normalization2 = Normalization(embedding_dim)
155
+
156
+ def forward(self, src, key_padding_mask=None, attn_mask=None):
157
+ h0 = src
158
+ h = self.normalization1(src)
159
+ h, _ = self.multiHeadAttention(q=h, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
160
+ h = h + h0
161
+ h1 = h
162
+ h = self.normalization2(h)
163
+ h = self.feedForward(h)
164
+ h2 = h + h1
165
+ return h2
166
+
167
+
168
+ class DecoderLayer(nn.Module):
169
+ def __init__(self, embedding_dim, n_head):
170
+ super(DecoderLayer, self).__init__()
171
+ self.multiHeadAttention = MultiHeadAttention(embedding_dim, n_head)
172
+ self.normalization1 = Normalization(embedding_dim)
173
+ self.feedForward = nn.Sequential(nn.Linear(embedding_dim, 512),
174
+ nn.ReLU(inplace=True),
175
+ nn.Linear(512, embedding_dim))
176
+ self.normalization2 = Normalization(embedding_dim)
177
+
178
+ def forward(self, tgt, memory, key_padding_mask=None, attn_mask=None):
179
+ h0 = tgt
180
+ tgt = self.normalization1(tgt)
181
+ memory = self.normalization1(memory)
182
+ h, w = self.multiHeadAttention(q=tgt, k=memory, v=memory, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
183
+ h = h + h0
184
+ h1 = h
185
+ h = self.normalization2(h)
186
+ h = self.feedForward(h)
187
+ h2 = h + h1
188
+ return h2, w
189
+
190
+
191
+ class Encoder(nn.Module):
192
+ def __init__(self, embedding_dim=128, n_head=8, n_layer=1):
193
+ super(Encoder, self).__init__()
194
+ self.layers = nn.ModuleList(EncoderLayer(embedding_dim, n_head) for i in range(n_layer))
195
+
196
+ def forward(self, src, key_padding_mask=None, attn_mask=None):
197
+ for layer in self.layers:
198
+ src = layer(src, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
199
+ return src
200
+
201
+
202
+ class Decoder(nn.Module):
203
+ def __init__(self, embedding_dim=128, n_head=8, n_layer=1):
204
+ super(Decoder, self).__init__()
205
+ self.layers = nn.ModuleList([DecoderLayer(embedding_dim, n_head) for i in range(n_layer)])
206
+
207
+ def forward(self, tgt, memory, key_padding_mask=None, attn_mask=None):
208
+ for layer in self.layers:
209
+ tgt, w = layer(tgt, memory, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
210
+ return tgt, w
211
+
212
+
213
+ class PolicyNet(nn.Module):
214
+ def __init__(self, input_dim, embedding_dim):
215
+ super(PolicyNet, self).__init__()
216
+ self.initial_embedding = nn.Linear(input_dim, embedding_dim) # layer for non-end position
217
+
218
+ self.current_embedding = nn.Linear(embedding_dim * 2, embedding_dim)
219
+
220
+ self.encoder = Encoder(embedding_dim=embedding_dim, n_head=8, n_layer=6)
221
+ self.decoder = Decoder(embedding_dim=embedding_dim, n_head=8, n_layer=1)
222
+ self.pointer = SingleHeadAttention(embedding_dim)
223
+
224
+ def encode_graph(self, node_inputs, node_padding_mask, edge_mask):
225
+ node_feature = self.initial_embedding(node_inputs)
226
+ enhanced_node_feature = self.encoder(src=node_feature, key_padding_mask=node_padding_mask, attn_mask=edge_mask)
227
+
228
+ return enhanced_node_feature
229
+
230
+ def output_policy(self, enhanced_node_feature, edge_inputs, current_index, edge_padding_mask, node_padding_mask):
231
+ k_size = edge_inputs.size()[2]
232
+ current_edge = torch.gather(edge_inputs, 1, current_index.repeat(1, 1, k_size))
233
+ current_edge = current_edge.permute(0, 2, 1)
234
+ embedding_dim = enhanced_node_feature.size()[2]
235
+
236
+ neigboring_feature = torch.gather(enhanced_node_feature, 1, current_edge.repeat(1, 1, embedding_dim))
237
+
238
+ current_node_feature = torch.gather(enhanced_node_feature, 1, current_index.repeat(1, 1, embedding_dim))
239
+
240
+ if edge_padding_mask is not None:
241
+ current_mask = torch.gather(edge_padding_mask, 1, current_index.repeat(1,1,k_size)).to(enhanced_node_feature.device)
242
+ # print(current_mask)
243
+ else:
244
+ current_mask = None
245
+ current_mask[:,:,0] = 1 # don't stay at current position
246
+
247
+ # ADDED: If nowhere to go, then STAY at current position
248
+ # #assert 0 in current_mask # Will cause sim to crash
249
+ if not 0 in current_mask:
250
+ current_mask[:,:,0] = 0
251
+
252
+ enhanced_current_node_feature, _ = self.decoder(current_node_feature, enhanced_node_feature, node_padding_mask)
253
+ enhanced_current_node_feature = self.current_embedding(torch.cat((enhanced_current_node_feature, current_node_feature), dim=-1))
254
+ logp = self.pointer(enhanced_current_node_feature, neigboring_feature, current_mask)
255
+ logp = logp.squeeze(1) # batch_size*k_size
256
+
257
+ return logp
258
+
259
+ def forward(self, node_inputs, edge_inputs, current_index, node_padding_mask=None, edge_padding_mask=None, edge_mask=None):
260
+ enhanced_node_feature = self.encode_graph(node_inputs, node_padding_mask, edge_mask)
261
+ logp = self.output_policy(enhanced_node_feature, edge_inputs, current_index, edge_padding_mask, node_padding_mask)
262
+ return logp
263
+
264
+
265
+ class QNet(nn.Module):
266
+ def __init__(self, input_dim, embedding_dim):
267
+ super(QNet, self).__init__()
268
+ self.initial_embedding = nn.Linear(input_dim, embedding_dim) # layer for non-end position
269
+ self.action_embedding = nn.Linear(embedding_dim*3, embedding_dim)
270
+
271
+ self.encoder = Encoder(embedding_dim=embedding_dim, n_head=8, n_layer=6)
272
+ self.decoder = Decoder(embedding_dim=embedding_dim, n_head=8, n_layer=1)
273
+
274
+ self.q_values_layer = nn.Linear(embedding_dim, 1)
275
+
276
+ def encode_graph(self, node_inputs, node_padding_mask, edge_mask):
277
+ embedding_feature = self.initial_embedding(node_inputs)
278
+ embedding_feature = self.encoder(src=embedding_feature, key_padding_mask=node_padding_mask, attn_mask=edge_mask)
279
+
280
+ return embedding_feature
281
+
282
+ def output_q_values(self, enhanced_node_feature, edge_inputs, current_index, edge_padding_mask, node_padding_mask):
283
+ k_size = edge_inputs.size()[2]
284
+ current_edge = torch.gather(edge_inputs, 1, current_index.repeat(1, 1, k_size))
285
+ current_edge = current_edge.permute(0, 2, 1)
286
+ embedding_dim = enhanced_node_feature.size()[2]
287
+
288
+ neigboring_feature = torch.gather(enhanced_node_feature, 1, current_edge.repeat(1, 1, embedding_dim))
289
+
290
+ current_node_feature = torch.gather(enhanced_node_feature, 1, current_index.repeat(1, 1, embedding_dim))
291
+
292
+ enhanced_current_node_feature, attention_weights = self.decoder(current_node_feature, enhanced_node_feature, node_padding_mask)
293
+ action_features = torch.cat((enhanced_current_node_feature.repeat(1, k_size, 1), current_node_feature.repeat(1, k_size, 1), neigboring_feature), dim=-1)
294
+ action_features = self.action_embedding(action_features)
295
+ q_values = self.q_values_layer(action_features)
296
+
297
+ if edge_padding_mask is not None:
298
+ current_mask = torch.gather(edge_padding_mask, 1, current_index.repeat(1, 1, k_size)).to(
299
+ enhanced_node_feature.device)
300
+ else:
301
+ current_mask = None
302
+ current_mask[:, :, 0] = 1 # don't stay at current position
303
+
304
+ # assert 0 in current_mask # Will cause sim to crash
305
+ if not 0 in current_mask:
306
+ current_mask[:,:,0] = 0
307
+
308
+ current_mask = current_mask.permute(0, 2, 1)
309
+ zero = torch.zeros_like(q_values).to(q_values.device)
310
+ q_values = torch.where(current_mask == 1, zero, q_values)
311
+
312
+ return q_values, attention_weights
313
+
314
+ def forward(self, node_inputs, edge_inputs, current_index, node_padding_mask=None, edge_padding_mask=None,
315
+ edge_mask=None):
316
+ enhanced_node_feature = self.encode_graph(node_inputs, node_padding_mask, edge_mask)
317
+ q_values, attention_weights = self.output_q_values(enhanced_node_feature, edge_inputs, current_index, edge_padding_mask, node_padding_mask)
318
+ return q_values, attention_weights
319
+
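Shape-check sketch for PolicyNet above (batch size, node count, k_size, and feature dimension are arbitrary assumptions; in this repo they come from parameter.py and the graph generator):

import torch

num_nodes, k_size, input_dim, embedding_dim = 36, 8, 4, 128
policy = PolicyNet(input_dim, embedding_dim)

node_inputs = torch.rand(1, num_nodes, input_dim)                        # per-node features
edge_inputs = torch.randint(0, num_nodes, (1, num_nodes, k_size))        # neighbor indices per node
current_index = torch.zeros(1, 1, 1, dtype=torch.long)                   # robot currently at node 0
edge_padding_mask = torch.zeros(1, num_nodes, k_size, dtype=torch.long)  # 1 = padded / invalid neighbor

logp = policy(node_inputs, edge_inputs, current_index,
              node_padding_mask=None, edge_padding_mask=edge_padding_mask, edge_mask=None)
print(logp.shape)  # torch.Size([1, k_size]): log-probabilities over the current node's neighbors

QNet accepts the same inputs and returns per-neighbor Q-values together with the decoder attention weights.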
node.py ADDED
@@ -0,0 +1,102 @@
1
+ #######################################################################
2
+ # Name: node.py
3
+ #
4
+ # - Contains info per node on graph (edge)
5
+ # - Contains: Position, Utility, Visitation History
6
+ #######################################################################
7
+
8
+ import sys
9
+ if sys.modules['TRAINING']:
10
+ from parameter import *
11
+ else:
12
+ from test_parameter import *
13
+
14
+ import numpy as np
15
+ import shapely.geometry
16
+
17
+
18
+ class Node():
19
+ def __init__(self, coords, frontiers, robot_belief):
20
+ self.coords = coords
21
+ self.observable_frontiers = []
22
+ self.sensor_range = SENSOR_RANGE
23
+ self.initialize_observable_frontiers(frontiers, robot_belief)
24
+ self.utility = self.get_node_utility()
25
+ if self.utility == 0:
26
+ self.zero_utility_node = True
27
+ else:
28
+ self.zero_utility_node = False
29
+
30
+ def initialize_observable_frontiers(self, frontiers, robot_belief):
31
+ dist_list = np.linalg.norm(frontiers - self.coords, axis=-1)
32
+ frontiers_in_range = frontiers[dist_list < self.sensor_range - 10]
33
+ for point in frontiers_in_range:
34
+ collision = self.check_collision(self.coords, point, robot_belief)
35
+ if not collision:
36
+ self.observable_frontiers.append(point)
37
+
38
+ def get_node_utility(self):
39
+ return len(self.observable_frontiers)
40
+
41
+ def update_observable_frontiers(self, observed_frontiers, new_frontiers, robot_belief):
42
+ if observed_frontiers != []:
43
+ observed_index = []
44
+ for i, point in enumerate(self.observable_frontiers):
45
+ if point[0] + point[1] * 1j in observed_frontiers[:, 0] + observed_frontiers[:, 1] * 1j:
46
+ observed_index.append(i)
47
+ for index in reversed(observed_index):
48
+ self.observable_frontiers.pop(index)
49
+ #
50
+ if new_frontiers != []:
51
+ dist_list = np.linalg.norm(new_frontiers - self.coords, axis=-1)
52
+ new_frontiers_in_range = new_frontiers[dist_list < self.sensor_range - 15]
53
+ for point in new_frontiers_in_range:
54
+ collision = self.check_collision(self.coords, point, robot_belief)
55
+ if not collision:
56
+ self.observable_frontiers.append(point)
57
+
58
+ self.utility = self.get_node_utility()
59
+ if self.utility == 0:
60
+ self.zero_utility_node = True
61
+ else:
62
+ self.zero_utility_node = False
63
+
64
+ def set_visited(self):
65
+ self.observable_frontiers = []
66
+ self.utility = 0
67
+ self.zero_utility_node = True
68
+
69
+ def check_collision(self, start, end, robot_belief):
70
+ collision = False
71
+ line = shapely.geometry.LineString([start, end])
72
+
73
+ sortx = np.sort([start[0], end[0]])
74
+ sorty = np.sort([start[1], end[1]])
75
+
76
+ # print(robot_belief.shape)
77
+ robot_belief = robot_belief[sorty[0]:sorty[1] + 1, sortx[0]:sortx[1] + 1]
78
+
79
+ occupied_area_index = np.where(robot_belief == 1)
80
+ occupied_area_coords = np.asarray(
81
+ [occupied_area_index[1] + sortx[0], occupied_area_index[0] + sorty[0]]).T
82
+ unexplored_area_index = np.where(robot_belief == 127)
83
+ unexplored_area_coords = np.asarray(
84
+ [unexplored_area_index[1] + sortx[0], unexplored_area_index[0] + sorty[0]]).T
85
+ unfree_area_coords = np.concatenate((occupied_area_coords, unexplored_area_coords))
86
+
87
+ # obstacles = []
88
+ for i in range(unfree_area_coords.shape[0]):
89
+ coords = ([(unfree_area_coords[i][0], unfree_area_coords[i][1]),
90
+ (unfree_area_coords[i][0] + 1, unfree_area_coords[i][1]),
91
+ (unfree_area_coords[i][0], unfree_area_coords[i][1] + 1),
92
+ (unfree_area_coords[i][0] + 1, unfree_area_coords[i][1] + 1)])
93
+ obstacle = shapely.geometry.Polygon(coords)
94
+ # obstacles.append(obstacle)
95
+ # if obstacles != []:
96
+ # all_obstacles = shapely.geometry.MultiPolygon(obstacles)
97
+ # print(obstacle.is_valid)
98
+ collision = line.intersects(obstacle)
99
+ if collision:
100
+ break
101
+
102
+ return collision
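Illustrative sketch of the utility computation above: a node's utility is the number of frontier cells it can observe. Collision checking is omitted here, and SENSOR_RANGE is an assumed value (the real one comes from parameter.py / test_parameter.py):

import numpy as np

SENSOR_RANGE = 80
coords = np.array([50, 50])
frontiers = np.array([[60, 55], [110, 50], [52, 48], [300, 300]])

dist = np.linalg.norm(frontiers - coords, axis=-1)
observable = frontiers[dist < SENSOR_RANGE - 10]   # same margin as initialize_observable_frontiers
utility = len(observable)
print(utility)                                     # 3 of the 4 frontiers are within range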
robot.py ADDED
@@ -0,0 +1,60 @@
1
+ #######################################################################
2
+ # Name: robot.py
3
+ #
4
+ # - Stores S(t), A(t), R(t), S(t+1)
5
+ #######################################################################
6
+
7
+ from copy import deepcopy
8
+ import torch
9
+
10
+ class Robot:
11
+ def __init__(self, robot_id, position, plot=False):
12
+ self.robot_id = robot_id
13
+ self.plot = plot
14
+ self.travel_dist = 0
15
+ self.robot_position = position
16
+ self.observations = None
17
+ self.trajectory_coords = []
18
+ self.targets_found_on_path = []
19
+
20
+ self.episode_buffer = []
21
+ for i in range(15):
22
+ self.episode_buffer.append([])
23
+
24
+ if self.plot:
25
+ # initialize the route
26
+ self.xPoints = [self.robot_position[0]]
27
+ self.yPoints = [self.robot_position[1]]
28
+
29
+ def save_observations(self, observations):
30
+ node_inputs, edge_inputs, current_index, node_padding_mask, edge_padding_mask, edge_mask = observations
31
+ self.episode_buffer[0] += deepcopy(node_inputs).to('cpu')
32
+ self.episode_buffer[1] += deepcopy(edge_inputs).to('cpu')
33
+ self.episode_buffer[2] += deepcopy(current_index).to('cpu')
34
+ self.episode_buffer[3] += deepcopy(node_padding_mask).to('cpu')
35
+ self.episode_buffer[4] += deepcopy(edge_padding_mask).to('cpu')
36
+ self.episode_buffer[5] += deepcopy(edge_mask).to('cpu')
37
+
38
+ def save_action(self, action_index):
39
+ self.episode_buffer[6] += action_index.unsqueeze(0).unsqueeze(0)
40
+
41
+ def save_reward_done(self, reward, done):
42
+ self.episode_buffer[7] += deepcopy(torch.FloatTensor([[[reward]]])).to('cpu')
43
+ self.episode_buffer[8] += deepcopy(torch.tensor([[[(int(done))]]])).to('cpu')
44
+ if self.plot:
45
+ self.xPoints.append(self.robot_position[0])
46
+ self.yPoints.append(self.robot_position[1])
47
+
48
+ def save_next_observations(self, observations):
49
+ node_inputs, edge_inputs, current_index, node_padding_mask, edge_padding_mask, edge_mask = observations
50
+ self.episode_buffer[9] += deepcopy(node_inputs).to('cpu')
51
+ self.episode_buffer[10] += deepcopy(edge_inputs).to('cpu')
52
+ self.episode_buffer[11] += deepcopy(current_index).to('cpu')
53
+ self.episode_buffer[12] += deepcopy(node_padding_mask).to('cpu')
54
+ self.episode_buffer[13] += deepcopy(edge_padding_mask).to('cpu')
55
+ self.episode_buffer[14] += deepcopy(edge_mask).to('cpu')
56
+
57
+ # NEW: ADDED TO SAVE TRAJECTORY COORDS DURING TTA
58
+ def save_trajectory_coords(self, robot_position_coords, num_target_found):
59
+ self.trajectory_coords.append(robot_position_coords)
60
+ self.targets_found_on_path.append(num_target_found)
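For reference, the episode_buffer layout implied by the save_* methods above (slot names are descriptive only):

#   0-5  : S(t)   - node_inputs, edge_inputs, current_index, node_padding_mask, edge_padding_mask, edge_mask
#   6    : A(t)   - action_index
#   7    : R(t)   - reward
#   8    : done flag
#   9-14 : S(t+1) - the same six observation tensors as slots 0-5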
sensor.py ADDED
@@ -0,0 +1,128 @@
1
+ #######################################################################
2
+ # Name: sensor.py
3
+ #
4
+ # - Computes sensor related checks (e.g. collision, utility etc)
5
+ #######################################################################
6
+
7
+ import sys
8
+ if sys.modules['TRAINING']:
9
+ from parameter import *
10
+ else:
11
+ from test_parameter import *
12
+
13
+ import math
14
+ import numpy as np
15
+ import copy
16
+
17
+ def collision_check(x0, y0, x1, y1, ground_truth, robot_belief):
18
+ x0 = x0.round()
19
+ y0 = y0.round()
20
+ x1 = x1.round()
21
+ y1 = y1.round()
22
+ dx, dy = abs(x1 - x0), abs(y1 - y0)
23
+ x, y = x0, y0
24
+ error = dx - dy
25
+ x_inc = 1 if x1 > x0 else -1
26
+ y_inc = 1 if y1 > y0 else -1
27
+ dx *= 2
28
+ dy *= 2
29
+
30
+ collision_flag = 0
31
+ max_collision = 10
32
+
33
+ while 0 <= x < ground_truth.shape[1] and 0 <= y < ground_truth.shape[0]:
34
+ k = ground_truth.item(y, x)
35
+ if k == 1 and collision_flag < max_collision:
36
+ collision_flag += 1
37
+ if collision_flag >= max_collision:
38
+ break
39
+
40
+ if k !=1 and collision_flag > 0:
41
+ break
42
+
43
+ if x == x1 and y == y1:
44
+ break
45
+
46
+ robot_belief.itemset((y, x), k)
47
+
48
+ if error > 0:
49
+ x += x_inc
50
+ error -= dy
51
+ else:
52
+ y += y_inc
53
+ error += dx
54
+
55
+ return robot_belief
56
+
57
+
58
+ def sensor_work(robot_position, sensor_range, robot_belief, ground_truth, sensor_model=SENSOR_MODEL):
59
+ x0 = robot_position[0]
60
+ y0 = robot_position[1]
61
+ rng_x = 0.5 * (ground_truth.shape[1] / NUM_COORDS_WIDTH)
62
+ rng_y = 0.5 * (ground_truth.shape[0] / NUM_COORDS_HEIGHT)
63
+
64
+ if sensor_model == "rectangular": # TODO: add collision check
65
+ max_x = min(x0 + int(math.ceil(rng_x)), ground_truth.shape[1])
66
+ min_x = max(x0 - int(math.ceil(rng_x)), 0)
67
+ max_y = min(y0 + int(math.ceil(rng_y)), ground_truth.shape[0])
68
+ min_y = max(y0 - int(math.ceil(rng_y)), 0)
69
+ robot_belief[min_y:max_y, min_x:max_x] = ground_truth[min_y:max_y, min_x:max_x]
70
+ else:
71
+ sensor_angle_inc = 0.5 / 180 * np.pi
72
+ sensor_angle = 0
73
+ while sensor_angle < 2 * np.pi:
74
+ x1 = x0 + np.cos(sensor_angle) * sensor_range
75
+ y1 = y0 + np.sin(sensor_angle) * sensor_range
76
+ robot_belief = collision_check(x0, y0, x1, y1, ground_truth, robot_belief)
77
+ sensor_angle += sensor_angle_inc
78
+ return robot_belief
79
+
80
+
81
+ def unexplored_area_check(x0, y0, x1, y1, current_belief):
82
+ x0 = x0.round()
83
+ y0 = y0.round()
84
+ x1 = x1.round()
85
+ y1 = y1.round()
86
+ dx, dy = abs(x1 - x0), abs(y1 - y0)
87
+ x, y = x0, y0
88
+ error = dx - dy
89
+ x_inc = 1 if x1 > x0 else -1
90
+ y_inc = 1 if y1 > y0 else -1
91
+ dx *= 2
92
+ dy *= 2
93
+
94
+ while 0 <= x < current_belief.shape[1] and 0 <= y < current_belief.shape[0]:
95
+ k = current_belief.item(y, x)
96
+ if x == x1 and y == y1:
97
+ break
98
+
99
+ if k == 1:
100
+ break
101
+
102
+ if k == 127:
103
+ current_belief.itemset((y, x), 0)
104
+ break
105
+
106
+ if error > 0:
107
+ x += x_inc
108
+ error -= dy
109
+ else:
110
+ y += y_inc
111
+ error += dx
112
+
113
+ return current_belief
114
+
115
+
116
+ def calculate_utility(waypoint_position, sensor_range, robot_belief):
117
+ sensor_angle_inc = 5 / 180 * np.pi
118
+ sensor_angle = 0
119
+ x0 = waypoint_position[0]
120
+ y0 = waypoint_position[1]
121
+ current_belief = copy.deepcopy(robot_belief)
122
+ while sensor_angle < 2 * np.pi:
123
+ x1 = x0 + np.cos(sensor_angle) * sensor_range
124
+ y1 = y0 + np.sin(sensor_angle) * sensor_range
125
+ current_belief = unexplored_area_check(x0, y0, x1, y1, current_belief)
126
+ sensor_angle += sensor_angle_inc
127
+ utility = np.sum(robot_belief == 127) - np.sum(current_belief == 127)
128
+ return utility
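Standalone sketch of the rectangular sensor update performed in sensor_work above (the map size and the 24x24 grid resolution are assumptions standing in for NUM_COORDS_WIDTH / NUM_COORDS_HEIGHT from parameter.py):

import math
import numpy as np

ground_truth = np.random.choice([1, 255], size=(240, 240))  # 1 = occupied, 255 = free
robot_belief = np.full_like(ground_truth, 127)               # 127 = unexplored
x0, y0 = 120, 80
rng_x = rng_y = 0.5 * (240 / 24)                             # half a grid cell, in pixels

max_x = min(x0 + int(math.ceil(rng_x)), ground_truth.shape[1])
min_x = max(x0 - int(math.ceil(rng_x)), 0)
max_y = min(y0 + int(math.ceil(rng_y)), ground_truth.shape[0])
min_y = max(y0 - int(math.ceil(rng_y)), 0)
robot_belief[min_y:max_y, min_x:max_x] = ground_truth[min_y:max_y, min_x:max_x]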
test_multi_robot_worker.py ADDED
@@ -0,0 +1,727 @@
1
+ #######################################################################
2
+ # Name: test_multi_robot_worker.py
3
+ #
4
+ # - Runs robot in environment for N steps
5
+ # - Collects & Returns S(t), A(t), R(t), S(t+1)
6
+ # - NOTE: Applicable for multiple robots
7
+ #######################################################################
8
+
9
+ from pathlib import Path
10
+ from test_parameter import *
11
+
12
+ from time import time
13
+ import imageio
14
+ import csv
15
+ import os
16
+ import copy
17
+ import numpy as np
18
+ import torch
19
+ import matplotlib.pyplot as plt
20
+ import json
21
+ from env import Env
22
+ from model import PolicyNet
23
+ from robot import Robot
24
+ from Taxabind.Taxabind.SatBind.watershed_segmentation import WatershedBinomial
25
+ from Taxabind.Taxabind.SatBind.kmeans_clustering import CombinedSilhouetteInertiaClusterer
26
+ from Taxabind.Taxabind.SatBind.clip_seg_tta import ClipSegTTA
27
+
28
+ np.seterr(invalid='raise', divide='raise')
29
+
30
+
31
+ class TestWorker:
32
+ def __init__(self, meta_agent_id, n_agent, policy_net, global_step, device='cuda', greedy=False, save_image=False, clip_seg_tta=None):
33
+ self.device = device
34
+ self.greedy = greedy
35
+ self.n_agent = n_agent
36
+ self.metaAgentID = meta_agent_id
37
+ self.global_step = global_step
38
+ self.k_size = K_SIZE
39
+ self.save_image = save_image
40
+ self.tta = TAXABIND_TTA
41
+ self.clip_seg_tta = clip_seg_tta
42
+
43
+ self.env = Env(map_index=self.global_step, n_agent=n_agent, k_size=self.k_size, plot=save_image, test=True)
44
+ self.local_policy_net = policy_net
45
+
46
+ self.robot_list = []
47
+ self.all_robot_positions = []
48
+ for i in range(self.n_agent):
49
+ # robot_position = self.env.node_coords[i]
50
+ robot_position = self.env.start_positions[i]
51
+ robot = Robot(robot_id=i, position=robot_position, plot=save_image)
52
+ self.robot_list.append(robot)
53
+ self.all_robot_positions.append(robot_position)
54
+
55
+ self.perf_metrics = dict()
56
+ self.bad_mask_init = False
57
+
58
+ # # TEMP - EXPORT START POSES FOR BASELINES
59
+ # json_path = "eval_start_positions.json"
60
+ # sat_to_start_pose_dict = {}
61
+ # print("len(self.env.map_list): ", len(self.env.map_list))
62
+ # for i in range(4000):
63
+ # print("i: ", i)
64
+ # map_idx = i % len(self.env.map_list)
65
+ # _, map_start_position = self.env.import_ground_truth(os.path.join(self.env.map_dir, self.env.map_list[map_idx]))
66
+ # self.clip_seg_tta.reset(sample_idx=i)
67
+ # sat_path = self.clip_seg_tta.gt_mask_name
68
+ # sat_to_start_pose_dict[sat_path] = tuple(map(int, map_start_position))
69
+ # # Save to json
70
+ # with open(json_path, 'w') as f:
71
+ # json.dump(sat_to_start_pose_dict, f)
72
+ # print("len(sat_to_start_pose_dict): ", len(sat_to_start_pose_dict))
73
+ # exit()
74
+
75
+ if self.tta:
76
+ # NOTE: Moved to test_driver.py for efficiency (avoid repeated init)
77
+ # self.clip_seg_tta = ClipSegTTA(
78
+ # img_dir=TAXABIND_IMG_DIR,
79
+ # imo_dir=TAXABIND_IMO_DIR,
80
+ # json_path=TAXABIND_INAT_JSON_PATH,
81
+ # sat_to_img_ids_json_path=TAXABIND_SAT_TO_IMG_IDS_JSON_PATH,
82
+ # patch_size=TAXABIND_PATCH_SIZE,
83
+ # sat_checkpoint_path=TAXABIND_SAT_CHECKPOINT_PATH,
84
+ # sample_index = TAXABIND_SAMPLE_INDEX, #global_step
85
+ # blur_kernel = TAXABIND_GAUSSIAN_BLUR_KERNEL,
86
+ # device=self.device,
87
+ # sat_to_img_ids_json_is_train_dict=False # for search ds val
88
+ # # sat_filtered_json_path=TAXABIND_FILTERED_INAT_JSON_PATH,
89
+ # )
90
+ if clip_seg_tta is not None:
91
+ self.clip_seg_tta.reset(sample_idx=self.global_step)
92
+ # print("Resetting for sample index: ", self.global_step)
93
+
94
+ # Override target positions in env
95
+ self.env.target_positions = [(pose[1], pose[0]) for pose in self.clip_seg_tta.target_positions] # Must transpose to (y, x) format
96
+
97
+ # Override GT seg mask for info_gain metric
98
+ if OVERRIDE_GT_MASK_DIR != "":
99
+ self.tta_gt_seg_path = os.path.join(OVERRIDE_GT_MASK_DIR, self.clip_seg_tta.gt_mask_name)
100
+ print("self.clip_seg_tta.gt_mask_name: ", self.clip_seg_tta.gt_mask_name)
101
+ if os.path.exists(self.tta_gt_seg_path):
102
+ self.env.gt_segmentation_mask = self.env.import_segmentation_mask(self.tta_gt_seg_path)
103
+ else:
104
+ print("\n\n!!!!!! WARNING: GT mask not found at path: ", self.tta_gt_seg_path)
105
+
106
+ if not USE_CLIP_PREDS and OVERRIDE_MASK_DIR != "":
107
+ # mask_name = self.clip_seg_tta.gt_mask_name.split('_')[:-TAX_HIERARCHY_TO_CONDENSE]
108
+ # mask_name = '_'.join(mask_name) + ".png"
109
+ score_mask_path = os.path.join(OVERRIDE_MASK_DIR, self.clip_seg_tta.gt_mask_name)
110
+ print("score_mask_path: ", score_mask_path)
111
+ if os.path.exists(score_mask_path):
112
+ self.env.segmentation_mask = self.env.import_segmentation_mask(score_mask_path)
113
+ self.env.begin(self.env.map_start_position)
114
+ else:
115
+ print(f"\n\n{RED}!!!!!! ERROR: Trying to override, but score mask not found at path:{NC} ", score_mask_path)
116
+ self.bad_mask_init = True
117
+
118
+
119
+
120
+ # # # # TEMP: Additional targets
121
+ # self.env.target_positions = [] # Reset
122
+ # print("self.env.target_positions", self.env.target_positions)
123
+ # # self.env.target_positions = [self.env.target_positions[2]]
124
+ # # self.env.target_positions = [self.env.target_positions[0]]
125
+ # # self.env.target_positions.append((251, 297))
126
+ # self.env.target_positions.append((40,40))
127
+ # self.env.target_positions.append((80,40))
128
+ # self.env.target_positions.append((120,40))
129
+ # self.env.target_positions.append((160,40))
130
+ # self.env.target_positions.append((200,40))
131
+
132
+ # Save clustered embeds from sat encoder
133
+ # In theory, we only need to do this once (same satellite map throughout)
134
+ if USE_CLIP_PREDS:
135
+ self.kmeans_clusterer = CombinedSilhouetteInertiaClusterer(
136
+ k_min=1,
137
+ k_max=8,
138
+ k_avg_max=4,
139
+ silhouette_threshold=0.15,
140
+ relative_threshold=0.15,
141
+ random_state=0,
142
+ min_patch_size=5, # smoothing parameter
143
+ n_smooth_iter=2, # smoothing parameter
144
+ ignore_label=-1,
145
+ plot=self.save_image,
146
+ gifs_dir = gifs_path
147
+ )
148
+ # Fit & predict (this will also plot the clusters before & after smoothing)
149
+ map_shape = (int(np.sqrt(self.clip_seg_tta.patch_embeds.shape[0])), int(np.sqrt(self.clip_seg_tta.patch_embeds.shape[0])))
150
+ self.kmeans_sat_embeds_clusters = self.kmeans_clusterer.fit_predict(
151
+ patch_embeds=self.clip_seg_tta.patch_embeds,
152
+ map_shape=map_shape,
153
+ )
154
+ print("Chosen k:", self.kmeans_clusterer.final_k)
155
+ # print("Smoothed labels shape:", self.kmeans_sat_embeds_clusters.shape)
156
+
157
+ if EXECUTE_TTA:
158
+ print("Will execute TTA...")
159
+
160
+ # Define Poisson TTA params
161
+ self.step_since_tta = 0
162
+ self.steps_to_first_tgt = None
163
+ self.steps_to_mid_tgt = None
164
+ self.steps_to_last_tgt = None
165
+
166
+ self.sim_0perc = None
167
+ self.sim_25perc = None
168
+ self.sim_50perc = None
169
+ self.sim_75perc = None
170
+ self.sim_100perc = None
171
+
172
+
173
+ def run_episode(self, curr_episode):
174
+
175
+ # Return all metrics as None if faulty mask init
176
+ if self.bad_mask_init:
177
+ self.perf_metrics['tax'] = None
178
+ self.perf_metrics['tax_first'] = None
179
+ self.perf_metrics['travel_dist'] = None
180
+ self.perf_metrics['travel_steps'] = None
181
+ self.perf_metrics['steps_to_first_tgt'] = None
182
+ self.perf_metrics['steps_to_mid_tgt'] = None
183
+ self.perf_metrics['steps_to_last_tgt'] = None
184
+ self.perf_metrics['explored_rate'] = None
185
+ self.perf_metrics['targets_found'] = None
186
+ self.perf_metrics['targets_total'] = None
187
+ self.perf_metrics['sim_0perc'] = None
188
+ self.perf_metrics['sim_25perc'] = None
189
+ self.perf_metrics['sim_50perc'] = None
190
+ self.perf_metrics['sim_75perc'] = None
191
+ self.perf_metrics['sim_100perc'] = None
192
+ self.perf_metrics['kmeans_k'] = None
193
+ self.perf_metrics['tgts_gt_score'] = None
194
+ self.perf_metrics['clip_inference_time'] = None
195
+ self.perf_metrics['tta_time'] = None
196
+ self.perf_metrics['info_gain'] = None
197
+ self.perf_metrics['total_info'] = None
198
+ self.perf_metrics['success_rate'] = None
199
+ return
200
+
201
+ eps_start = time()
202
+ done = False
203
+ for robot_id, deciding_robot in enumerate(self.robot_list):
204
+ deciding_robot.observations = self.get_observations(deciding_robot.robot_position)
205
+ if self.tta and USE_CLIP_PREDS:
206
+ self.env.segmentation_info_mask = np.expand_dims(self.clip_seg_tta.heatmap.T.flatten(), axis=1)
207
+ self.env.segmentation_info_mask_unnormalized = np.expand_dims(self.clip_seg_tta.heatmap_unnormalized.T.flatten(), axis=1)
208
+ # print("self.env.segmentation_info_mask.shape", self.env.segmentation_info_mask.shape)
209
+
210
+ ### Run episode for 128 steps ###
211
+ for step in range(NUM_EPS_STEPS):
212
+
213
+ # print("\n\n\n~~~~~~~~~~~~~~~~~~~~~~ Step: ", step, " ~~~~~~~~~~~~~~~~~~~~~~")
214
+
215
+ next_position_list = []
216
+ dist_list = []
217
+ travel_dist_list = []
218
+ dist_array = np.zeros((self.n_agent, 1))
219
+ for robot_id, deciding_robot in enumerate(self.robot_list):
220
+ observations = deciding_robot.observations
221
+ # if self.env.node_coords.shape[0] >= self.k_size:
222
+ # deciding_robot.save_observations(observations)
223
+
224
+ ### Forward pass through policy to get next position ###
225
+ next_position, action_index = self.select_node(observations)
226
+ # if self.env.node_coords.shape[0] >= self.k_size:
227
+ # deciding_robot.save_action(action_index)
228
+
229
+ dist = np.linalg.norm(next_position - deciding_robot.robot_position)
230
+
231
+ ### Log results of action (e.g. distance travelled) ###
232
+ dist_array[robot_id] = dist
233
+ dist_list.append(dist)
234
+ travel_dist_list.append(deciding_robot.travel_dist)
235
+ next_position_list.append(next_position)
236
+ self.all_robot_positions[robot_id] = next_position
237
+
238
+ arriving_sequence = np.argsort(dist_list)
239
+ next_position_list = np.array(next_position_list)
240
+ dist_list = np.array(dist_list)
241
+ travel_dist_list = np.array(travel_dist_list)
242
+ next_position_list = next_position_list[arriving_sequence]
243
+ dist_list = dist_list[arriving_sequence]
244
+ travel_dist_list = travel_dist_list[arriving_sequence]
245
+
246
+ ### Take Action (Deconflict if 2 agents choose the same target position) ###
247
+ next_position_list, dist_list = self.solve_conflict(arriving_sequence, next_position_list, dist_list)
248
+ # dist_travelled = np.linalg.norm(next_position - deciding_robot.robot_position)
249
+ # deciding_robot.travel_dist += dist_travelled
250
+ # deciding_robot.robot_position = next_position
251
+ reward_list, done = self.env.multi_robot_step(next_position_list, dist_list, travel_dist_list)
252
+
253
+ ### Update observations + rewards from action ###
254
+ for reward, robot_id in zip(reward_list, arriving_sequence):
255
+ robot = self.robot_list[robot_id]
256
+ robot.save_trajectory_coords(self.env.find_index_from_coords(robot.robot_position), self.env.num_new_targets_found)
257
+
258
+ # # TTA Update via Poisson Test (with KMeans clustering stats)
259
+ if self.tta and USE_CLIP_PREDS and EXECUTE_TTA:
260
+ self.poisson_tta_update(robot, self.global_step, step)
261
+
262
+ robot.observations = self.get_observations(robot.robot_position)
263
+ # if self.env.node_coords.shape[0] >= self.k_size:
264
+ robot.save_reward_done(reward, done)
265
+ # robot.save_next_observations(robot.observations)
266
+
267
+ # Update metrics
268
+ # NOTE: For 1 robot for now
269
+ self.log_metrics(step=step) # robot.targets_found_on_path)
270
+
271
+
272
+ ### Save a frame to generate gif of robot trajectories ###
273
+ if self.save_image:
274
+ robots_route = []
275
+ for robot in self.robot_list:
276
+ robots_route.append([robot.xPoints, robot.yPoints])
277
+ if not os.path.exists(gifs_path):
278
+ os.makedirs(gifs_path)
279
+ sound_id_override = None if self.clip_seg_tta.sound_ids == [] else self.clip_seg_tta.sound_ids[0]
280
+ if TAXABIND_TTA and USE_CLIP_PREDS:
281
+ self.env.plot_env(
282
+ self.global_step,
283
+ gifs_path,
284
+ step,
285
+ max(travel_dist_list),
286
+ robots_route,
287
+ img_path_override=self.clip_seg_tta.img_paths[0], # Viz 1st
288
+ sat_path_override=self.clip_seg_tta.imo_path,
289
+ msk_name_override=self.clip_seg_tta.species_name,
290
+ sound_id_override=sound_id_override,
291
+ colormap_mid_val=np.max(self.clip_seg_tta.heatmap_unnormalized_initial)
292
+ )
293
+ else:
294
+ self.env.plot_env(
295
+ self.global_step,
296
+ gifs_path,
297
+ step,
298
+ max(travel_dist_list),
299
+ robots_route,
300
+ img_path_override=self.clip_seg_tta.img_paths[0], # Viz 1st
301
+ sat_path_override=self.clip_seg_tta.imo_path,
302
+ msk_name_override=self.clip_seg_tta.species_name,
303
+ sound_id_override=sound_id_override,
304
+ )
305
+
306
+ if done:
307
+ break
308
+
309
+ if self.tta:
310
+ tax = Path(self.clip_seg_tta.gt_mask_name).stem
311
+ self.perf_metrics['tax'] = " ".join(tax.split("_")[1:])
312
+ self.perf_metrics['tax_first'] = tax.split("_")[1]
313
+ else:
314
+ self.perf_metrics['tax'] = None
315
+ self.perf_metrics['tax_first'] = None
316
+ self.perf_metrics['travel_dist'] = max(travel_dist_list)
317
+ self.perf_metrics['travel_steps'] = step + 1
318
+ self.perf_metrics['steps_to_first_tgt'] = self.steps_to_first_tgt
319
+ self.perf_metrics['steps_to_mid_tgt'] = self.steps_to_mid_tgt
320
+ self.perf_metrics['steps_to_last_tgt'] = self.steps_to_last_tgt
321
+ self.perf_metrics['explored_rate'] = self.env.explored_rate
322
+ self.perf_metrics['targets_found'] = self.env.targets_found_rate
323
+ self.perf_metrics['targets_total'] = len(self.env.target_positions)
324
+ self.perf_metrics['sim_0perc'] = self.sim_0perc
325
+ self.perf_metrics['sim_25perc'] = self.sim_25perc
326
+ self.perf_metrics['sim_50perc'] = self.sim_50perc
327
+ self.perf_metrics['sim_75perc'] = self.sim_75perc
328
+ self.perf_metrics['sim_100perc'] = self.sim_100perc
329
+ if USE_CLIP_PREDS:
330
+ self.perf_metrics['kmeans_k'] = self.kmeans_clusterer.final_k
331
+ self.perf_metrics['tgts_gt_score'] = self.clip_seg_tta.tgts_gt_score
332
+ self.perf_metrics['clip_inference_time'] = self.clip_seg_tta.clip_inference_time
333
+ self.perf_metrics['tta_time'] = self.clip_seg_tta.tta_time
334
+ else:
335
+ self.perf_metrics['kmeans_k'] = None
336
+ self.perf_metrics['tgts_gt_score'] = None
337
+ self.perf_metrics['clip_inference_time'] = None
338
+ self.perf_metrics['tta_time'] = None
339
+ if OVERRIDE_GT_MASK_DIR != "" and os.path.exists(self.tta_gt_seg_path):
340
+ self.perf_metrics['info_gain'] = self.env.info_gain
341
+ self.perf_metrics['total_info'] = self.env.total_info
342
+ else:
343
+ self.perf_metrics['info_gain'] = None
344
+ self.perf_metrics['total_info'] = None
345
+ if FORCE_LOGGING_DONE_TGTS_FOUND and self.env.targets_found_rate == 1.0:
346
+ self.perf_metrics['success_rate'] = True
347
+ else:
348
+ self.perf_metrics['success_rate'] = done
349
+
350
+ # save gif
351
+ if self.save_image:
352
+ path = gifs_path
353
+ self.make_gif(path, curr_episode)
354
+
355
+ print(YELLOW, f"[Eps {curr_episode} Completed] Time Taken: {time()-eps_start:.2f}s, Steps: {step+1}", NC)
356
+
357
+ def get_observations(self, robot_position):
358
+ """ Get robot's sensor observation of environment given position """
359
+ current_node_index = self.env.find_index_from_coords(robot_position)
360
+ current_index = torch.tensor([current_node_index]).unsqueeze(0).unsqueeze(0).to(self.device) # (1,1,1)
361
+
362
+ node_coords = copy.deepcopy(self.env.node_coords)
363
+ graph = copy.deepcopy(self.env.graph)
364
+ node_utility = copy.deepcopy(self.env.node_utility)
365
+ guidepost = copy.deepcopy(self.env.guidepost)
366
+ # segmentation_info_mask = copy.deepcopy(self.env.segmentation_info_mask)
367
+ segmentation_info_mask = copy.deepcopy(self.env.filtered_seg_info_mask)
368
+
369
+ # ADDED - SEGMENTATION INFORMATION MASK
370
+ n_nodes = node_coords.shape[0]
371
+
372
+ node_coords = node_coords / 640
373
+ node_utility = node_utility / 50
374
+
375
+ node_utility_inputs = node_utility.reshape((n_nodes, 1))
376
+
377
+ occupied_node = np.zeros((n_nodes, 1))
378
+ for position in self.all_robot_positions:
379
+ index = self.env.find_index_from_coords(position)
380
+ if index == current_index.item():
381
+ occupied_node[index] = -1
382
+ else:
383
+ occupied_node[index] = 1
384
+
385
+ # node_inputs = np.concatenate((node_coords, node_utility_inputs, guidepost, occupied_node), axis=1)
386
+ node_inputs = np.concatenate((node_coords, segmentation_info_mask, guidepost), axis=1)
387
+ # node_inputs = np.concatenate((node_coords, segmentation_info_mask, guidepost), axis=1)
388
+ node_inputs = torch.FloatTensor(node_inputs).unsqueeze(0).to(self.device) # (1, node_padding_size+1, 3)
389
+
390
+ # assert node_coords.shape[0] < self.node_padding_size
391
+ # padding = torch.nn.ZeroPad2d((0, 0, 0, self.node_padding_size - node_coords.shape[0]))
392
+ # node_inputs = padding(node_inputs)
393
+
394
+ # node_padding_mask = torch.zeros((1, 1, node_coords.shape[0]), dtype=torch.int64).to(self.device)
395
+ # node_padding = torch.ones((1, 1, self.node_padding_size - node_coords.shape[0]), dtype=torch.int64).to(
396
+ # self.device)
397
+ # node_padding_mask = torch.cat((node_padding_mask, node_padding), dim=-1)
398
+ node_padding_mask = None
399
+
400
+ graph = list(graph.values())
401
+ edge_inputs = []
402
+ for node in graph:
403
+ node_edges = list(map(int, node))
404
+ edge_inputs.append(node_edges)
405
+
406
+ bias_matrix = self.calculate_edge_mask(edge_inputs)
407
+ edge_mask = torch.from_numpy(bias_matrix).float().unsqueeze(0).to(self.device)
408
+
409
+ # assert len(edge_inputs) < self.node_padding_size
410
+ # padding = torch.nn.ConstantPad2d(
411
+ # (0, self.node_padding_size - len(edge_inputs), 0, self.node_padding_size - len(edge_inputs)), 1)
412
+ # edge_mask = padding(edge_mask)
413
+ # padding2 = torch.nn.ZeroPad2d((0, 0, 0, self.node_padding_size - len(edge_inputs)))
414
+
415
+ for edges in edge_inputs:
416
+ while len(edges) < self.k_size:
417
+ edges.append(0)
418
+
419
+ edge_inputs = torch.tensor(edge_inputs).unsqueeze(0).to(self.device) # (1, node_padding_size+1, k_size)
420
+ # edge_inputs = padding2(edge_inputs)
421
+
422
+ edge_padding_mask = torch.zeros((1, len(edge_inputs), K_SIZE), dtype=torch.int64).to(self.device)
423
+ one = torch.ones_like(edge_padding_mask, dtype=torch.int64).to(self.device)
424
+ edge_padding_mask = torch.where(edge_inputs == 0, one, edge_padding_mask)
425
+
426
+ observations = node_inputs, edge_inputs, current_index, node_padding_mask, edge_padding_mask, edge_mask
427
+ return observations
428
+
429
+
430
+ def select_node(self, observations):
431
+ """ Forward pass through policy to get next position to go to on map """
432
+ node_inputs, edge_inputs, current_index, node_padding_mask, edge_padding_mask, edge_mask = observations
433
+ with torch.no_grad():
434
+ logp_list = self.local_policy_net(node_inputs, edge_inputs, current_index, node_padding_mask, edge_padding_mask, edge_mask)
435
+
436
+ if self.greedy:
437
+ action_index = torch.argmax(logp_list, dim=1).long()
438
+ else:
439
+ action_index = torch.multinomial(logp_list.exp(), 1).long().squeeze(1)
440
+
441
+ next_node_index = edge_inputs[:, current_index.item(), action_index.item()]
442
+
443
+ next_position = self.env.node_coords[next_node_index]
444
+
445
+ return next_position, action_index
446
+
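As a small illustration (toy, assumed 3-action distribution) of the greedy vs. sampled selection used in select_node above:

import torch

logp_list = torch.log(torch.tensor([[0.1, 0.7, 0.2]]))                 # (1, k) log-probabilities over neighbours
greedy_idx = torch.argmax(logp_list, dim=1).long()                     # always tensor([1])
sampled_idx = torch.multinomial(logp_list.exp(), 1).long().squeeze(1)  # usually 1, occasionally 0 or 2
print(greedy_idx.item(), sampled_idx.item())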
447
+ def solve_conflict(self, arriving_sequence, next_position_list, dist_list):
448
+ """ Deconflict if 2 agents choose the same target position """
449
+ for j, [robot_id, next_position] in enumerate(zip(arriving_sequence, next_position_list)):
450
+ moving_robot = self.robot_list[robot_id]
451
+ # if next_position[0] + next_position[1] * 1j in (next_position_list[:, 0] + next_position_list[:, 1] * 1j)[:j]:
452
+ # dist_to_next_position = np.argsort(np.linalg.norm(self.env.node_coords - next_position, axis=1))
453
+ # k = 0
454
+ # while next_position[0] + next_position[1] * 1j in (next_position_list[:, 0] + next_position_list[:, 1] * 1j)[:j]:
455
+ # k += 1
456
+ # next_position = self.env.node_coords[dist_to_next_position[k]]
457
+
458
+ dist = np.linalg.norm(next_position - moving_robot.robot_position)
459
+ next_position_list[j] = next_position
460
+ dist_list[j] = dist
461
+ moving_robot.travel_dist += dist
462
+ moving_robot.robot_position = next_position
463
+
464
+ return next_position_list, dist_list
465
+
466
+ def work(self, currEpisode):
467
+ '''
468
+ Interacts with the environment. The agent gets either gradients or an experience buffer.
469
+ '''
470
+ self.run_episode(currEpisode)
471
+
472
+ def calculate_edge_mask(self, edge_inputs):
473
+ size = len(edge_inputs)
474
+ bias_matrix = np.ones((size, size))
475
+ for i in range(size):
476
+ for j in range(size):
477
+ if j in edge_inputs[i]:
478
+ bias_matrix[i][j] = 0
479
+ return bias_matrix
480
+
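For reference, a toy run of the same masking logic (assumed 3-node graph) showing that neighbours get 0 and non-neighbours get 1 in the attention bias matrix:

import numpy as np

def calculate_edge_mask(edge_inputs):
    # Same logic as the method above: 0 where an edge exists, 1 otherwise.
    size = len(edge_inputs)
    bias_matrix = np.ones((size, size))
    for i in range(size):
        for j in range(size):
            if j in edge_inputs[i]:
                bias_matrix[i][j] = 0
    return bias_matrix

# Node 0 <-> 1 <-> 2, with self-loops:
print(calculate_edge_mask([[0, 1], [1, 0, 2], [2, 1]]))
# [[0. 0. 1.]
#  [0. 0. 0.]
#  [1. 0. 0.]]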
481
+ def make_gif(self, path, n):
482
+ """ Generate a gif given list of images """
483
+ with imageio.get_writer('{}/{}_target_rate_{:.2f}.gif'.format(path, n, self.env.targets_found_rate), mode='I',
484
+ fps=5) as writer:
485
+ for frame in self.env.frame_files:
486
+ image = imageio.imread(frame)
487
+ writer.append_data(image)
488
+ print('gif complete\n')
489
+
490
+ # Remove files
491
+ for filename in self.env.frame_files[:-1]:
492
+ os.remove(filename)
493
+
494
+ # For KMeans clusterer gif during TTA
495
+ if self.tta:
496
+
497
+ # print("self.kmeans_clusterer.kmeans_frame_files", self.kmeans_clusterer.kmeans_frame_files)
498
+ with imageio.get_writer('{}/{}_kmeans_stats.gif'.format(path, n), mode='I',
499
+ fps=5) as writer:
500
+ for frame in self.kmeans_clusterer.kmeans_frame_files:
501
+ image = imageio.imread(frame)
502
+ writer.append_data(image)
503
+ print('Kmeans Clusterer gif complete\n')
504
+
505
+ # Remove files
506
+ for filename in self.kmeans_clusterer.kmeans_frame_files[:-1]:
507
+ os.remove(filename)
508
+
509
+ ################################################################################
510
+ # ADDED
511
+ ################################################################################
512
+
513
+ def log_metrics(self, step):
514
+ # Update tgt found metrics
515
+ if self.steps_to_first_tgt is None and self.env.num_targets_found == 1:
516
+ self.steps_to_first_tgt = step + 1
517
+ if self.steps_to_mid_tgt is None and self.env.num_targets_found == int(len(self.env.target_positions) / 2):
518
+ self.steps_to_mid_tgt = step + 1
519
+ if self.steps_to_last_tgt is None and self.env.num_targets_found == len(self.env.target_positions):
520
+ self.steps_to_last_tgt = step + 1
521
+
522
+ # Update sim metrics
523
+ if OVERRIDE_GT_MASK_DIR != "" and os.path.exists(self.tta_gt_seg_path):
524
+ side_dim = int(np.sqrt(self.env.segmentation_info_mask.shape[0]))
525
+ pred_mask = self.env.segmentation_info_mask.squeeze().reshape((side_dim, side_dim)).T
526
+ gt_mask = self.env.gt_segmentation_info_mask.squeeze().reshape((side_dim, side_dim)).T
527
+ if step == 0:
528
+ self.sim_0perc = self.norm_distance(pred_mask, gt_mask)
529
+ elif step == int(NUM_EPS_STEPS * 0.25):
530
+ self.sim_25perc = self.norm_distance(pred_mask, gt_mask)
531
+ elif step == int(NUM_EPS_STEPS * 0.5):
532
+ self.sim_50perc = self.norm_distance(pred_mask, gt_mask)
533
+ elif step == int(NUM_EPS_STEPS * 0.75):
534
+ self.sim_75perc = self.norm_distance(pred_mask, gt_mask)
535
+ elif step == NUM_EPS_STEPS - 1:
536
+ self.sim_100perc = self.norm_distance(pred_mask, gt_mask)
537
+
538
+ def norm_distance(self, P, Q, norm_type="L2"):
539
+
540
+ # Normalize both grids to [0,1]
541
+ try:
542
+ if P.max() != P.min():
543
+ P_norm = (P - P.min()) / (P.max() - P.min())
544
+ else:
545
+ P_norm = P
546
+ if Q.max() != Q.min():
547
+ Q_norm = (Q - Q.min()) / (Q.max() - Q.min())
548
+ else:
549
+ Q_norm = Q
550
+ except FloatingPointError as e:
551
+ print(f"{RED}Caught floating point error:{NC} {e}")
552
+ print("P min/max:", P.min(), P.max())
553
+ print("Q min/max:", Q.min(), Q.max())
554
+ print("Q: ", Q)
555
+ similarity = None
556
+ return similarity
557
+
558
+ if norm_type == "L1":
559
+ num_cells = P.shape[0] * P.shape[1]
560
+ # L1 distance: sum of absolute differences
561
+ l1_dist = np.sum(np.abs(P_norm - Q_norm))
562
+ # Normalize: maximum L1 distance is num_cells (if every cell differs by 1)
563
+ similarity = 1 - (l1_dist / num_cells)
564
+ elif norm_type == "L2":
565
+ # L2 distance via Root Mean Squared Error (RMSE)
566
+ rmse = np.sqrt(np.mean((P_norm - Q_norm)**2))
567
+ # Since both grids are in [0,1], maximum RMSE is 1.
568
+ similarity = 1 - rmse
569
+ else:
570
+ raise ValueError("norm_type must be either 'L1' or 'L2'")
571
+
572
+ return similarity
573
+
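A worked example of the L2 branch above on tiny, assumed 2x2 grids (both normalised to [0,1] before the RMSE):

import numpy as np

P = np.array([[0.0, 1.0], [2.0, 3.0]])          # un-normalised prediction heatmap
Q = np.array([[0.0, 1.0], [1.0, 0.0]])          # ground-truth heatmap, already in [0, 1]

P_norm = (P - P.min()) / (P.max() - P.min())    # [[0, 1/3], [2/3, 1]]
rmse = np.sqrt(np.mean((P_norm - Q) ** 2))
similarity = 1 - rmse                           # 1.0 would mean identical heatmaps
print(round(similarity, 3))                     # 0.376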
574
+ def transpose_flat_idx(self, idx, H=NUM_COORDS_HEIGHT, W=NUM_COORDS_WIDTH):
575
+ """
576
+ Given a flattened index X in an H x W (row-major) matrix,
577
+ return the new index X' after transposing the matrix.
578
+ """
579
+ row = idx // W
580
+ col = idx % W
581
+ idx_T = col * H + row
582
+ return idx_T
583
+
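A quick worked example of the index transposition above, for an assumed 3 x 4 grid:

H, W = 3, 4
idx = 7                       # row 1, col 3 in the 3x4 row-major grid
row, col = idx // W, idx % W
idx_T = col * H + row         # 3*3 + 1 = 10 in the transposed 4x3 grid
print(row, col, idx_T)        # 1 3 10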
584
+ def poisson_tta_update(self, robot, episode, step):
585
+
586
+ # TODO: Move into TTA loop to save computation
587
+ # Generate Kmeans Clusters Stats
588
+ visited_indices = [self.transpose_flat_idx(idx) for idx in robot.trajectory_coords]
589
+ region_stats_dict = self.kmeans_clusterer.compute_region_statistics(
590
+ self.kmeans_sat_embeds_clusters,
591
+ self.clip_seg_tta.heatmap_unnormalized,
592
+ visited_indices,
593
+ episode_num=episode,
594
+ step_num=step
595
+ )
596
+
597
+ # Prep & execute TTA
598
+ self.step_since_tta += 1
599
+ if robot.targets_found_on_path[-1] or self.step_since_tta % STEPS_PER_TTA == 0: # Allow even if no positive at start
600
+
601
+ # for _ in range(NUM_TTA_STEPS):
602
+
603
+ filt_traj_coords = [self.transpose_flat_idx(idx) for idx in robot.trajectory_coords]
604
+ filt_targets_found_on_path = robot.targets_found_on_path
605
+ num_cells = self.clip_seg_tta.heatmap.shape[0] * self.clip_seg_tta.heatmap.shape[1]
606
+ # per_sample_weight_scale = [min(1.0, (step/num_cells)) for _ in filt_traj_coords]
607
+ # per_sample_weight_scale = [0.5 * min(1.0, (step/100)) for _ in filt_traj_coords]
608
+ # per_sample_weight_scale = [1.0 for _ in filt_traj_coords]
609
+ pos_sample_weight_scale, neg_sample_weight_scale = [], []
610
+ for i, sample_loc in enumerate(filt_traj_coords):
611
+ label = self.kmeans_clusterer.get_label_id(sample_loc)
612
+ num_patches = region_stats_dict[label]['num_patches']
613
+ patches_visited = region_stats_dict[label]['patches_visited']
614
+ expectation = region_stats_dict[label]['expectation']
615
+
616
+ ## BEST so far: exponent like focal loss to wait for more samples before confidently decreasing
617
+ pos_weight = 4.0 # 2.0
618
+ # pos_weight = 1.0 + 4.0 * min(1.0, (patches_visited/(num_patches))**GAMMA_EXPONENT) # (1,5)
619
+ # pos_weight = 1.0 + 4.0 * min(1.0, (patches_visited/(3*expectation))**GAMMA_EXPONENT)
620
+ # neg_weight = min(1.0, (patches_visited/(3*expectation))**GAMMA_EXPONENT)
621
+ neg_weight = min(1.0, (patches_visited/(3*num_patches))**GAMMA_EXPONENT) # (0,1)
622
+ pos_sample_weight_scale.append(pos_weight)
623
+ neg_sample_weight_scale.append(neg_weight)
624
+
625
+ ## Prelim thoughts (BAD - quickly reduces low-probability regions even with few samples)
626
+ # neg_weight = min(1.0, patches_visited/(3*expectation))
627
+ # local_probs = self.kmeans_clusterer.get_probs(sample_loc, self.clip_seg_tta.heatmap)
628
+ # neg_weight = min(local_probs/2, patches_visited/(3*expectation)) # 2*expectation (if don't want TTA scheduler - 3x TTA)
629
+ # neg_weight = min(1.0, patches_visited/num_patches) # 2*expectation (if don't want TTA scheduler - 3x TTA)
630
+
631
+ ## Hacky, but works better (does not decrease low-probability regions too fast)
632
+ # if label == 0:
633
+ # neg_weight = min(0.5, patches_visited/(3*expectation))
634
+ # else:
635
+ # neg_weight = min(0.05, patches_visited/(3*expectation))
636
+ # squared
637
+
638
+ # Adaptive LR (as samples increase, increase LR to fit more datapoints - else won't update)
639
+ adaptive_lr = MIN_LR + (MAX_LR - MIN_LR) * (step / num_cells)
640
+ # print("!!! adaptive_lr", adaptive_lr)
641
+ # adaptive_lr = 2e-6
642
+
643
+ # NOTE: Not as good as adaptive LR (because it is discrete)
644
+ # # Num TTA steps scheduler
645
+ # min_tta_steps = 3
646
+ # max_tta_steps = 10
647
+ # num_tta_steps = int((max_tta_steps - min_tta_steps) * (step / num_cells) + min_tta_steps)
648
+ # print("!!! num_tta_steps", num_tta_steps)
649
+
650
+ # TTA Update
651
+ self.clip_seg_tta.execute_tta(
652
+ filt_traj_coords,
653
+ filt_targets_found_on_path,
654
+ tta_steps=NUM_TTA_STEPS,
655
+ lr=adaptive_lr,
656
+ pos_sample_weight=pos_sample_weight_scale,
657
+ neg_sample_weight=neg_sample_weight_scale,
658
+ modality=MODALITY,
659
+ query_variety=QUERY_VARIETY,
660
+ target_found_idxs=self.env.target_found_idxs,
661
+ reset_weights=RESET_WEIGHTS
662
+ )
663
+ self.env.segmentation_info_mask = np.expand_dims(self.clip_seg_tta.heatmap.T.flatten(), axis=1)
664
+ self.env.segmentation_info_mask_unnormalized = np.expand_dims(self.clip_seg_tta.heatmap_unnormalized.T.flatten(), axis=1)
665
+ self.step_since_tta = 0
666
+
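To make the weighting and learning-rate schedules above concrete, a small worked example with assumed counts (GAMMA_EXPONENT=2, MIN_LR=1e-6, MAX_LR=1e-5, and a 24x24 grid of 576 cells):

GAMMA_EXPONENT, MIN_LR, MAX_LR, num_cells = 2, 1e-6, 1e-5, 576

# Negative-sample weight grows slowly with visits to a cluster (focal-loss-like exponent).
num_patches = 100
for patches_visited in (10, 50, 100):
    neg_weight = min(1.0, (patches_visited / (3 * num_patches)) ** GAMMA_EXPONENT)
    print(patches_visited, round(neg_weight, 4))    # 0.0011, 0.0278, 0.1111

# Adaptive LR ramps linearly with episode progress.
step = 144                                          # a quarter of the way through 576 cells
adaptive_lr = MIN_LR + (MAX_LR - MIN_LR) * (step / num_cells)
print(adaptive_lr)                                  # 3.25e-06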
667
+ ################################################################################
668
+
669
+
670
+ # Main entry point
671
+ if __name__ == "__main__":
672
+
673
+ # CHANGE ME!
674
+ currEpisode = 0
675
+
676
+ # Prepare the model
677
+ # device = torch.device('cpu') #if USE_GPU_TRAINING else torch.device('cpu')
678
+ device = torch.device('cuda') if USE_GPU else torch.device('cpu')
679
+ policy_net = PolicyNet(INPUT_DIM, EMBEDDING_DIM).to(device)
680
+ # script_dir = os.path.dirname(os.path.abspath(__file__))
681
+ script_dir = Path(__file__).resolve().parent
682
+ print("real_script_dir: ", script_dir)
683
+ # checkpoint = torch.load(f'{script_dir}/modules/vlm_search/{model_path}/{MODEL_NAME}')
684
+ checkpoint = torch.load(f'{model_path}/{MODEL_NAME}')
685
+ policy_net.load_state_dict(checkpoint['policy_model'])
686
+ print('Model loaded!')
687
+ # print(next(policy_net.parameters()).device)
688
+
689
+ # Init Taxabind here (only need to init once)
690
+ if TAXABIND_TTA:
691
+ # self.clip_seg_tta = None
692
+ clip_seg_tta = ClipSegTTA(
693
+ img_dir=TAXABIND_IMG_DIR,
694
+ imo_dir=TAXABIND_IMO_DIR,
695
+ json_path=TAXABIND_INAT_JSON_PATH,
696
+ sat_to_img_ids_json_path=TAXABIND_SAT_TO_IMG_IDS_JSON_PATH,
697
+ patch_size=TAXABIND_PATCH_SIZE,
698
+ sat_checkpoint_path=TAXABIND_SAT_CHECKPOINT_PATH,
699
+ sample_index = 0, # Set using 'reset' in worker
700
+ blur_kernel = TAXABIND_GAUSSIAN_BLUR_KERNEL,
701
+ device=device,
702
+ sat_to_img_ids_json_is_train_dict=False, # for search ds val
703
+ tax_to_filter_val=QUERY_TAX,
704
+ load_model=USE_CLIP_PREDS,
705
+ initial_modality=INITIAL_MODALITY,
706
+ sound_data_path = TAXABIND_SOUND_DATA_PATH,
707
+ sound_checkpoint_path=TAXABIND_SOUND_CHECKPOINT_PATH,
708
+ # sat_filtered_json_path=TAXABIND_FILTERED_INAT_JSON_PATH,
709
+ )
710
+ print("ClipSegTTA Loaded!")
711
+ else:
712
+ clip_seg_tta = None
713
+
714
+ # Define TestWorker
715
+ planner = TestWorker(
716
+ meta_agent_id=0,
717
+ n_agent=1,
718
+ policy_net=policy_net,
719
+ global_step=3,
720
+ device='cuda',
721
+ greedy=True,
722
+ save_image=SAVE_GIFS,
723
+ clip_seg_tta=clip_seg_tta
724
+ )
725
+ planner.run_episode(currEpisode)
726
+
727
+ print("planner.perf_metrics: ", planner.perf_metrics)
test_parameter.py ADDED
@@ -0,0 +1,183 @@
1
+ ############################################################################################
2
+ # Name: test_parameter.py
3
+ #
4
+ # NOTE: Change all your hyper-params here!
5
+ # Simple How-To Guide:
6
+ # 1. CLIP TTA: USE_CLIP_PREDS = True, EXECUTE_TTA = True
7
+ # 2. CLIP (No TTA): USE_CLIP_PREDS = True, EXECUTE_TTA = False
8
+ # 3. Custom masks (e.g. LLMSeg): USE_CLIP_PREDS = False, EXECUTE_TTA = False
9
+ ############################################################################################
10
+
11
+ import os
12
+ import sys
13
+ sys.modules['TRAINING'] = False # False = Inference Testing
14
+
15
+ ###############################################################
16
+ OPT_VARS = {}
17
+ def getenv(var_name, default=None, cast_type=str):
18
+ try:
19
+ value = os.environ.get(var_name, None)
20
+ if value is None:
21
+ result = default
22
+ elif cast_type == bool:
23
+ result = value.lower() in ("true", "1", "yes")
24
+ else:
25
+ result = cast_type(value)
26
+ except (ValueError, TypeError):
27
+ result = default
28
+
29
+ OPT_VARS[var_name] = result # Log the result
30
+ return result
31
+ ###############################################################
32
+
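For instance (hypothetical values), parameters can be overridden from a launch script by exporting environment variables before this module is imported:

import os

os.environ["POLICY"] = "RL"
os.environ["SAVE_GIFS"] = "false"        # bools accept "true"/"1"/"yes", case-insensitive
os.environ["NUM_EPS_STEPS"] = "256"
os.environ["QUERY_TAX"] = "Animalia Chordata Mammalia Rodentia"

# import test_parameter                  # getenv() then picks these up and logs them in OPT_VARS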
33
+ POLICY = getenv("POLICY", default="RL", cast_type=str)
34
+ # TAX_HIERARCHY_TO_CONDENSE = 3 # Remove N layers of the taxonomy hierarchy from the back
35
+
36
+ NUM_TEST = 800 # Overridden if TAXABIND_TTA is True and performing search ds val
37
+ NUM_RUN = 1
38
+ SAVE_GIFS = getenv("SAVE_GIFS", default=True, cast_type=bool) # do you want to save GIFs
39
+ SAVE_TRAJECTORY = False # do you want to save per-step metrics
40
+ SAVE_LENGTH = False # do you want to save per-episode metrics
41
+ VIZ_GRAPH_EDGES = False # do you want to visualize the graph edges
42
+ # MODEL_NAME = "pure_coverage_no_pose_obs_230325_stage1.pth" # checkpoint.pth
43
+ # MODEL_NAME = "STAGE2_20k_vlm_search_24x24_290225_NO_TARGET_REWARDS_600steps.pth" # checkpoint.pth
44
+ # MODEL_NAME = "vlm_search_24x24_230225_NO_TARGET_REWARDS_600steps.pth" # checkpoint.pth
45
+ # MODEL_NAME = "vlm_search_20x20_200125_256steps_CORRECT_REWARDS.pth" # checkpoint.pth
46
+ MODEL_NAME = "STAGE1_vlm_search_24x24_040425_no_tgt_rewards_iNAT_DS_16k.pth"
47
+
48
+ NUM_EPS_STEPS = getenv("NUM_EPS_STEPS", default=384, cast_type=int)
49
+ TERMINATE_ON_TGTS_FOUND = False # Whether to terminate episode when all targets found
50
+ FORCE_LOGGING_DONE_TGTS_FOUND = True # Whether to force csv logging when all targets found
51
+ FIX_START_POSITION = getenv("FIX_START_POSITION", default=True, cast_type=bool) # Whether to fix the starting position of the robots (middle index)
52
+
53
+ ## Whether to override initial score mask from CLIP
54
+ USE_CLIP_PREDS = getenv("USE_CLIP_PREDS", default=True, cast_type=bool) # If false, use custom masks from OVERRIDE_MASK_DIR
55
+ OVERRIDE_MASK_DIR = getenv("OVERRIDE_MASK_DIR", default="/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/FINAL_COUNTS_COMBINED_top4000_pos_top55000_neg/Baselines_Masks/LLaVA_GroundedSAM/v1_010425/LLaVA_GroundedSAM_out_mask_val_in/out_mask_val_in", cast_type=str)
56
+ # OVERRIDE_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/score_maps/v1_200325/pred/Animalia_Chordata_Mammalia_Rodentia"
57
+ # OVERRIDE_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/score_maps/v1_200325/pred/Animalia_Chordata_Mammalia_Artiodactyla"
58
+ # OVERRIDE_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/score_maps/v1_200325/pred/Animalia_Arthropoda_Arachnida_Araneae"
59
+ # OVERRIDE_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/score_maps/v1_200325/pred/Plantae_Tracheophyta_Magnoliopsida_Caryophyllales"
60
+ # OVERRIDE_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/FINAL_COUNTS_COMBINED_top4000_pos_top55000_neg/Baselines_Masks/LLaVA_GroundedSAM/v1_010425/LLaVA_GroundedSAM_out_mask_val_in/out_mask_val_in"
61
+ # OVERRIDE_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/FINAL_COUNTS_COMBINED_top4000_pos_top55000_neg/Baselines_Masks/LLaVA_GroundedSAM/v1_010425/LLaVA_GroundedSAM_out_mask_val_out/out_mask_val_out"
62
+
63
+ # Used to calculate the info_gain metric
64
+ OVERRIDE_GT_MASK_DIR = getenv("OVERRIDE_GT_MASK_DIR", default="", cast_type=str)
65
+ # OVERRIDE_GT_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/score_maps/v1_200325/gt/val_in_4gsnet_score_map"
66
+ # OVERRIDE_GT_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/score_maps/v1_200325/gt/val_out_4gsnet_score_map"
67
+
68
+ #######################################################################
69
+ # iNAT TTA
70
+ #######################################################################
71
+
72
+ # Query Params
73
+ QUERY_TAX = getenv("QUERY_TAX", default="", cast_type=str) # "" = Test all tax
74
+ # QUERY_TAX = "Animalia Chordata Mammalia Rodentia" # search_val_in
75
+ # QUERY_TAX = "Animalia Chordata Mammalia Artiodactyla" # search_val_in
76
+ # QUERY_TAX = "Animalia Arthropoda Arachnida Araneae" # search_val_out
77
+ # QUERY_TAX = "Plantae Tracheophyta Magnoliopsida Caryophyllales" # search_val_out
78
+
79
+ # TTA PARAMS
80
+ EXECUTE_TTA = getenv("EXECUTE_TTA", default=True, cast_type=bool) # Whether to execute TTA mask updates
81
+ STEPS_PER_TTA = 20 # no. steps before each TTA series
82
+ NUM_TTA_STEPS = 1 # no. of TTA steps during each series
83
+ INITIAL_MODALITY = getenv("INITIAL_MODALITY", default="image", cast_type=str) # "image", "text", "combined"
84
+ MODALITY = getenv("MODALITY", default="image", cast_type=str) # "image", "text", "combined"
85
+ QUERY_VARIETY = getenv("QUERY_VARIETY", default=False, cast_type=bool) # whether to use a variety of queries during TTA
86
+ RESET_WEIGHTS = True
87
+ MIN_LR = 1e-6
88
+ MAX_LR = 1e-5 # 1e-5
89
+ GAMMA_EXPONENT = 2 # 2
90
+
91
+ # Paths related to taxabind (TRAIN w/ TARGETS)
92
+ TAXABIND_TTA = True # Whether to init TTA classes - FOR NOW: Always True
93
+ TAXABIND_IMG_DIR = '/mnt/hdd/inat2021_ds/inat21'
94
+ TAXABIND_IMO_DIR = '/mnt/hdd/inat2021_ds/sat_train_jpg_512px' # sat_test_jpg_256px, sat_test_jpg_512px
95
+ TAXABIND_INAT_JSON_PATH = '/mnt/hdd/inat2021_ds/inat21_train.json' # no filter needed
96
+ TAXABIND_FILTERED_INAT_JSON_PATH = '/mnt/hdd/inat2021_ds/2_OTHERS/inat21_filtered_pixel_clip_train.json' # no filter needed
97
+ # TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = '/mnt/hdd/inat2021_ds/filtered_mapping_sat_to_img_ids_val.json'
98
+ TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = getenv("TAXABIND_SAT_TO_IMG_IDS_JSON_PATH", default="/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/FINAL_COUNTS_COMBINED_top4000_pos_top55000_neg/search_val_in.json", cast_type=str)
99
+ # TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/search_val_in.json"
100
+ # TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/search_val_out.json"
101
+ TAXABIND_PATCH_SIZE=14
102
+ TAXABIND_SAT_CHECKPOINT_PATH="/home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px_search_ds_filtered/pixel_clip_512px_search_ds_100625_CLIP-L-336_FINAL_SPLIT_LARGE_BUGFIX_CLIP_TRAIN_CORRECT_VAL_IN_TAX_FILTER_TGT_ONLY/satbind-epoch=02-val_loss=2.50_BACKUP.ckpt" # "/home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px_search_ds_filtered/pixel_clip_512px_search_ds_070425_CLIP-L-336_FINAL_SPLIT_LARGE/satbind-epoch=02-val_loss=2.48-BACKUP.ckpt"
103
+ TAXABIND_GAUSSIAN_BLUR_KERNEL = (5,5)
104
+ TAXABIND_SAMPLE_INDEX = 8 # DEBUG (Starting point) 5, 6, 8
105
+
106
+ # Sound
107
+ TAXABIND_SOUND_DATA_PATH = '/mnt/hdd/inat2021_ds/2_OTHERS/sound_test'
108
+ TAXABIND_SOUND_CHECKPOINT_PATH = "/home/user/Taxabind/TaxaBind/SoundBind/checkpoints/BUGFIX_CLIP_TRAIN_CORRECT_without_out_domain_taxs_v4_220625/soundbind-epoch=19-val_loss=3.92_BACKUP.ckpt"
109
+
110
+ # # Paths related to taxabind (TRAIN w/ TARGETS)
111
+ # TAXABIND_TTA = True
112
+ # TAXABIND_IMG_DIR = '/mnt/hdd/inat2021_ds/inat21'
113
+ # TAXABIND_IMO_DIR = '/mnt/hdd/inat2021_ds/sat_train_jpg_512px' # sat_test_jpg_256px, sat_test_jpg_512px
114
+ # TAXABIND_INAT_JSON_PATH = '/mnt/hdd/inat2021_ds/inat21_train.json' # no filter needed
115
+ # TAXABIND_FILTERED_INAT_JSON_PATH = '/mnt/hdd/inat2021_ds/2_OTHERS/inat21_filtered_pixel_clip_train.json' # no filter needed
116
+ # # TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = '/mnt/hdd/inat2021_ds/filtered_mapping_sat_to_img_ids_val.json'
117
+ # TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = "/mnt/hdd/inat2021_ds/target_search_ds/OLD/taxon_sat_target_search_100x_per_10-20counts.json"
118
+ # TAXABIND_PATCH_SIZE=14
119
+ # TAXABIND_SAT_CHECKPOINT_PATH="/home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px/pixel_clip_512px_190225_CLIP-L-336/satbind-epoch=02-val_loss=2.40.ckpt" # /home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px/pixel_clip_512px_160225_NO_DATASET_SHUFFLE/satbind-epoch=02-val_loss=2.26-BACKUP.ckpt
120
+ # TAXABIND_SAMPLE_INDEX = 99 # (Starting point) 99,141
121
+
122
+ # # Paths related to taxabind (VAL)
123
+ # TAXABIND_TTA = True
124
+ # TAXABIND_IMG_DIR = '/mnt/hdd/inat2021_ds/inat21'
125
+ # TAXABIND_IMO_DIR = '/mnt/hdd/inat2021_ds/sat_test_jpg_512px' # sat_test_jpg_256px, sat_test_jpg_512px
126
+ # TAXABIND_INAT_JSON_PATH = '/mnt/hdd/inat2021_ds/inat21_val.json' # no filter needed
127
+ # TAXABIND_FILTERED_INAT_JSON_PATH = '/mnt/hdd/inat2021_ds/inat21_filtered_pixel_clip_val.json' # no filter needed
128
+ # # TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = '/mnt/hdd/inat2021_ds/filtered_mapping_sat_to_img_ids_val.json'
129
+ # TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = "/mnt/hdd/inat2021_ds/target_search_ds/taxon_sat_target_search_100x_per_10-20counts.json"
130
+ # TAXABIND_PATCH_SIZE=14
131
+ # TAXABIND_SAT_CHECKPOINT_PATH="/home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px/pixel_clip_512px_190225_CLIP-L-336/satbind-epoch=02-val_loss=2.40.ckpt" # /home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px/pixel_clip_512px_160225_NO_DATASET_SHUFFLE/satbind-epoch=02-val_loss=2.26-BACKUP.ckpt
132
+ # TAXABIND_SAMPLE_INDEX = 45 # TEMP
133
+
134
+ #######################################################################
135
+ # Pretraining
136
+ #######################################################################
137
+
138
+ # TODO: Get rid of the LISA stuff...
139
+ # If LISA trained clss
140
+ GRIDMAP_SET_DIR = "Maps/flair_real_maps/envs_val_trained_clss"
141
+ MASK_SET_DIR = "Maps/flair_real_maps_lisa_pred/finetuned_LISA_v3_original_losses/flair_lisa_soft_masks_trained_clss_v3" # original_LISA, finetuned_LISA_v3_original_losses
142
+ TARGETS_SET_DIR = "Maps/flair_real_maps/masks_val_trained_clss" # If empty, then targets assumed to be on MASK_SET_DIR
143
+ RAW_IMG_PATH_DICT = "Maps/flair_real_maps/flair-ds-paths-filtered-with-scores-val-trained-clss.csv" # flair-ds-paths-filtered-with-scores-train.csv, flair-ds-paths-filtered-with-scores-val-trained-clss.csv, flair-ds-paths-filtered-with-scores-val-out-clss.csv
144
+
145
+ # # If LISA out clss
146
+ # GRIDMAP_SET_DIR = "Maps/flair_real_maps/envs_val_out_clss"
147
+ # MASK_SET_DIR = "Maps/flair_real_maps_lisa_pred/finetuned_LISA_v3_original_losses/flair_lisa_soft_masks_out_clss_v3" # original_LISA, finetuned_LISA_v3_original_losses
148
+ # TARGETS_SET_DIR = "Maps/flair_real_maps/masks_val_out_clss" # If empty, then targets assumed to be on MASK_SET_DIR
149
+ # RAW_IMG_PATH_DICT = "Maps/flair_real_maps/flair-ds-paths-filtered-with-scores-val-out-clss.csv" # flair-ds-paths-filtered-with-scores-train.csv, flair-ds-paths-filtered-with-scores-val-trained-clss.csv, flair-ds-paths-filtered-with-scores-val-out-clss.csv
150
+
151
+ #######################################################################
152
+
153
+ NUM_ROBOTS = 1
154
+ NUM_COORDS_WIDTH=24 # How many node coords across width?
155
+ NUM_COORDS_HEIGHT=24 # How many node coords across height?
156
+ HIGH_INFO_REWARD_RATIO = 0.75 # Ratio of rewards for moving to uncertain area (high info vs low info)
157
+
158
+ SENSOR_RANGE=80 # Only applicable to 'circle' sensor model
159
+ SENSOR_MODEL="rectangular" # "rectangular", "circle" (NOTE: no collision check for rectangular)
160
+
161
+ INPUT_DIM = 4
162
+ EMBEDDING_DIM = 128
163
+ K_SIZE = 8 # 8
164
+
165
+ USE_GPU = True # do you want to use GPUS?
166
+ NUM_GPU = getenv("NUM_GPU", default=2, cast_type=int) # the number of GPUs
167
+ NUM_META_AGENT = getenv("NUM_META_AGENT", default=4, cast_type=int) # the number of processes
168
+ FOLDER_NAME = 'inference'
169
+ model_path = f'{FOLDER_NAME}/model'
170
+ gifs_path = f'{FOLDER_NAME}/test_results/gifs'
171
+ trajectory_path = f'{FOLDER_NAME}/test_results/trajectory'
172
+ length_path = f'{FOLDER_NAME}/test_results/length'
173
+ log_path = f'{FOLDER_NAME}/test_results/log'
174
+ CSV_EXPT_NAME = getenv("CSV_EXPT_NAME", default="data", cast_type=str)
175
+ # trajectory_path = f'results/trajectory'
176
+ # length_path = f'results/length'
177
+
178
+ # COLORS (for printing)
179
+ RED='\033[1;31m'
180
+ GREEN='\033[1;32m'
181
+ YELLOW='\033[1;93m'
182
+ NC_BOLD='\033[1m' # Bold, No Color
183
+ NC='\033[0m' # No Color