derektan commited on
Commit
88b55c7
·
1 Parent(s): 51adbef

Gen heatmap using CPU

Browse files
Taxabind/TaxaBind/SatBind/clip_seg_tta.py CHANGED
@@ -438,8 +438,13 @@ class ClipSegTTA:
438
  self.model_local.imo_encoder.to(self.device)
439
  self.model_local.bio_model.to(self.device)
440
 
 
 
 
 
 
441
  # Save final heatmap after TTA steps
442
- self.generate_heatmap(img, imo, txt, sound=sound, modality=modality)
443
 
444
 
445
  def generate_query_embeds(self, img, imo, txt, sound=None, modality="image"):
 
438
  self.model_local.imo_encoder.to(self.device)
439
  self.model_local.bio_model.to(self.device)
440
 
441
+ # Move tensors to CPU before generating heatmap to avoid dtype/device mismatches
442
+ img_cpu = img.to(self.device) if isinstance(img, torch.Tensor) else img
443
+ imo_cpu = imo.to(self.device)
444
+ sound_cpu = sound.to(self.device) if (sound is not None and isinstance(sound, torch.Tensor)) else sound
445
+
446
  # Save final heatmap after TTA steps
447
+ self.generate_heatmap(img_cpu, imo_cpu, txt, sound=sound_cpu, modality=modality)
448
 
449
 
450
  def generate_query_embeds(self, img, imo, txt, sound=None, modality="image"):