derektan commited on
Commit
575e58e
·
1 Parent(s): 88b55c7

Try saving heatmap outside clip_seg_tta

Browse files
Taxabind/TaxaBind/SatBind/clip_seg_tta.py CHANGED
@@ -40,6 +40,7 @@ from types import SimpleNamespace
40
  import torch.nn as nn
41
  import spaces # integration with ZeroGPU on hf
42
  from torch.autograd import enable_grad # handy alias
 
43
 
44
  # import matplotlib
45
  # matplotlib.use("Agg") # <-- key line to avoid tkinter dependency
@@ -416,6 +417,9 @@ class ClipSegTTA:
416
  self.tta_time = time.time() - start_time
417
  # print("self.tta_time: ", self.tta_time)
418
 
 
 
 
419
  # Visualization every 'num_viz_steps' steps (if enabled)
420
  if (step + 1) % num_viz_steps == 0 and viz_heatmap:
421
  # Visualize only the first sample in the batch
@@ -430,6 +434,9 @@ class ClipSegTTA:
430
  species_name=self.species_name
431
  )
432
 
 
 
 
433
  ## NOTE: Added due to app.py (to allocate to GPU only when needed on HF)
434
  # if self.device.type == "cuda":
435
  print("Deallocating models from GPU...")
@@ -438,13 +445,8 @@ class ClipSegTTA:
438
  self.model_local.imo_encoder.to(self.device)
439
  self.model_local.bio_model.to(self.device)
440
 
441
- # Move tensors to CPU before generating heatmap to avoid dtype/device mismatches
442
- img_cpu = img.to(self.device) if isinstance(img, torch.Tensor) else img
443
- imo_cpu = imo.to(self.device)
444
- sound_cpu = sound.to(self.device) if (sound is not None and isinstance(sound, torch.Tensor)) else sound
445
 
446
- # Save final heatmap after TTA steps
447
- self.generate_heatmap(img_cpu, imo_cpu, txt, sound=sound_cpu, modality=modality)
448
 
449
 
450
  def generate_query_embeds(self, img, imo, txt, sound=None, modality="image"):
 
40
  import torch.nn as nn
41
  import spaces # integration with ZeroGPU on hf
42
  from torch.autograd import enable_grad # handy alias
43
+ import copy
44
 
45
  # import matplotlib
46
  # matplotlib.use("Agg") # <-- key line to avoid tkinter dependency
 
417
  self.tta_time = time.time() - start_time
418
  # print("self.tta_time: ", self.tta_time)
419
 
420
+ # Make deep copy of self.model_local.imo_encoder
421
+ self.model_local.imo_encoder = copy.deepcopy(self.model_local.imo_encoder)
422
+
423
  # Visualization every 'num_viz_steps' steps (if enabled)
424
  if (step + 1) % num_viz_steps == 0 and viz_heatmap:
425
  # Visualize only the first sample in the batch
 
434
  species_name=self.species_name
435
  )
436
 
437
+ # Save final heatmap after TTA steps
438
+ self.generate_heatmap(img, imo, txt, sound=sound, modality=modality)
439
+
440
  ## NOTE: Added due to app.py (to allocate to GPU only when needed on HF)
441
  # if self.device.type == "cuda":
442
  print("Deallocating models from GPU...")
 
445
  self.model_local.imo_encoder.to(self.device)
446
  self.model_local.bio_model.to(self.device)
447
 
448
+ return self.heatmap
 
 
 
449
 
 
 
450
 
451
 
452
  def generate_query_embeds(self, img, imo, txt, sound=None, modality="image"):
test_multi_robot_worker.py CHANGED
@@ -653,7 +653,7 @@ class TestWorker:
653
  # print("!!! num_tta_steps", num_tta_steps)
654
 
655
  # TTA Update
656
- self.clip_seg_tta.execute_tta(
657
  filt_traj_coords,
658
  filt_targets_found_on_path,
659
  tta_steps=NUM_TTA_STEPS,
@@ -665,6 +665,9 @@ class TestWorker:
665
  target_found_idxs=self.env.target_found_idxs,
666
  reset_weights=RESET_WEIGHTS
667
  )
 
 
 
668
  self.env.segmentation_info_mask = np.expand_dims(self.clip_seg_tta.heatmap.T.flatten(), axis=1)
669
  self.env.segmentation_info_mask_unnormalized = np.expand_dims(self.clip_seg_tta.heatmap_unnormalized.T.flatten(), axis=1)
670
  self.step_since_tta = 0
 
653
  # print("!!! num_tta_steps", num_tta_steps)
654
 
655
  # TTA Update
656
+ heatmap = self.clip_seg_tta.execute_tta(
657
  filt_traj_coords,
658
  filt_targets_found_on_path,
659
  tta_steps=NUM_TTA_STEPS,
 
665
  target_found_idxs=self.env.target_found_idxs,
666
  reset_weights=RESET_WEIGHTS
667
  )
668
+
669
+ self.clip_seg_tta.heatmap = heatmap
670
+
671
  self.env.segmentation_info_mask = np.expand_dims(self.clip_seg_tta.heatmap.T.flatten(), axis=1)
672
  self.env.segmentation_info_mask_unnormalized = np.expand_dims(self.clip_seg_tta.heatmap_unnormalized.T.flatten(), axis=1)
673
  self.step_since_tta = 0