Spaces:
Running
on
Zero
Running
on
Zero
derektan
committed on
Commit
·
575e58e
1
Parent(s):
88b55c7
Try saving heatmap outside clip_seg_tta
Browse files
Taxabind/TaxaBind/SatBind/clip_seg_tta.py
CHANGED
@@ -40,6 +40,7 @@ from types import SimpleNamespace
|
|
40 |
import torch.nn as nn
|
41 |
import spaces # integration with ZeroGPU on hf
|
42 |
from torch.autograd import enable_grad # handy alias
|
|
|
43 |
|
44 |
# import matplotlib
|
45 |
# matplotlib.use("Agg") # <-- key line to avoid tkinter dependency
|
@@ -416,6 +417,9 @@ class ClipSegTTA:
|
|
416 |
self.tta_time = time.time() - start_time
|
417 |
# print("self.tta_time: ", self.tta_time)
|
418 |
|
|
|
|
|
|
|
419 |
# Visualization every 'num_viz_steps' steps (if enabled)
|
420 |
if (step + 1) % num_viz_steps == 0 and viz_heatmap:
|
421 |
# Visualize only the first sample in the batch
|
@@ -430,6 +434,9 @@ class ClipSegTTA:
|
|
430 |
species_name=self.species_name
|
431 |
)
|
432 |
|
|
|
|
|
|
|
433 |
## NOTE: Added due to app.py (to allocate to GPU only when needed on HF)
|
434 |
# if self.device.type == "cuda":
|
435 |
print("Deallocating models from GPU...")
|
@@ -438,13 +445,8 @@ class ClipSegTTA:
|
|
438 |
self.model_local.imo_encoder.to(self.device)
|
439 |
self.model_local.bio_model.to(self.device)
|
440 |
|
441 |
-
|
442 |
-
img_cpu = img.to(self.device) if isinstance(img, torch.Tensor) else img
|
443 |
-
imo_cpu = imo.to(self.device)
|
444 |
-
sound_cpu = sound.to(self.device) if (sound is not None and isinstance(sound, torch.Tensor)) else sound
|
445 |
|
446 |
-
# Save final heatmap after TTA steps
|
447 |
-
self.generate_heatmap(img_cpu, imo_cpu, txt, sound=sound_cpu, modality=modality)
|
448 |
|
449 |
|
450 |
def generate_query_embeds(self, img, imo, txt, sound=None, modality="image"):
|
|
|
40 |
import torch.nn as nn
|
41 |
import spaces # integration with ZeroGPU on hf
|
42 |
from torch.autograd import enable_grad # handy alias
|
43 |
+
import copy
|
44 |
|
45 |
# import matplotlib
|
46 |
# matplotlib.use("Agg") # <-- key line to avoid tkinter dependency
|
|
|
417 |
self.tta_time = time.time() - start_time
|
418 |
# print("self.tta_time: ", self.tta_time)
|
419 |
|
420 |
+
# Make deep copy of self.model_local.imo_encoder
|
421 |
+
self.model_local.imo_encoder = copy.deepcopy(self.model_local.imo_encoder)
|
422 |
+
|
423 |
# Visualization every 'num_viz_steps' steps (if enabled)
|
424 |
if (step + 1) % num_viz_steps == 0 and viz_heatmap:
|
425 |
# Visualize only the first sample in the batch
|
|
|
434 |
species_name=self.species_name
|
435 |
)
|
436 |
|
437 |
+
# Save final heatmap after TTA steps
|
438 |
+
self.generate_heatmap(img, imo, txt, sound=sound, modality=modality)
|
439 |
+
|
440 |
## NOTE: Added due to app.py (to allocate to GPU only when needed on HF)
|
441 |
# if self.device.type == "cuda":
|
442 |
print("Deallocating models from GPU...")
|
|
|
445 |
self.model_local.imo_encoder.to(self.device)
|
446 |
self.model_local.bio_model.to(self.device)
|
447 |
|
448 |
+
return self.heatmap
|
|
|
|
|
|
|
449 |
|
|
|
|
|
450 |
|
451 |
|
452 |
def generate_query_embeds(self, img, imo, txt, sound=None, modality="image"):
|
test_multi_robot_worker.py
CHANGED
@@ -653,7 +653,7 @@ class TestWorker:
|
|
653 |
# print("!!! num_tta_steps", num_tta_steps)
|
654 |
|
655 |
# TTA Update
|
656 |
-
self.clip_seg_tta.execute_tta(
|
657 |
filt_traj_coords,
|
658 |
filt_targets_found_on_path,
|
659 |
tta_steps=NUM_TTA_STEPS,
|
@@ -665,6 +665,9 @@ class TestWorker:
|
|
665 |
target_found_idxs=self.env.target_found_idxs,
|
666 |
reset_weights=RESET_WEIGHTS
|
667 |
)
|
|
|
|
|
|
|
668 |
self.env.segmentation_info_mask = np.expand_dims(self.clip_seg_tta.heatmap.T.flatten(), axis=1)
|
669 |
self.env.segmentation_info_mask_unnormalized = np.expand_dims(self.clip_seg_tta.heatmap_unnormalized.T.flatten(), axis=1)
|
670 |
self.step_since_tta = 0
|
|
|
653 |
# print("!!! num_tta_steps", num_tta_steps)
|
654 |
|
655 |
# TTA Update
|
656 |
+
heatmap = self.clip_seg_tta.execute_tta(
|
657 |
filt_traj_coords,
|
658 |
filt_targets_found_on_path,
|
659 |
tta_steps=NUM_TTA_STEPS,
|
|
|
665 |
target_found_idxs=self.env.target_found_idxs,
|
666 |
reset_weights=RESET_WEIGHTS
|
667 |
)
|
668 |
+
|
669 |
+
self.clip_seg_tta.heatmap = heatmap
|
670 |
+
|
671 |
self.env.segmentation_info_mask = np.expand_dims(self.clip_seg_tta.heatmap.T.flatten(), axis=1)
|
672 |
self.env.segmentation_info_mask_unnormalized = np.expand_dims(self.clip_seg_tta.heatmap_unnormalized.T.flatten(), axis=1)
|
673 |
self.step_since_tta = 0
|