Spaces:
Running
on
Zero
Running
on
Zero
derektan
commited on
Commit
·
88b55c7
1
Parent(s):
51adbef
Gen heatmap using CPU
Browse files
Taxabind/TaxaBind/SatBind/clip_seg_tta.py
CHANGED
@@ -438,8 +438,13 @@ class ClipSegTTA:
|
|
438 |
self.model_local.imo_encoder.to(self.device)
|
439 |
self.model_local.bio_model.to(self.device)
|
440 |
|
|
|
|
|
|
|
|
|
|
|
441 |
# Save final heatmap after TTA steps
|
442 |
-
self.generate_heatmap(
|
443 |
|
444 |
|
445 |
def generate_query_embeds(self, img, imo, txt, sound=None, modality="image"):
|
|
|
438 |
self.model_local.imo_encoder.to(self.device)
|
439 |
self.model_local.bio_model.to(self.device)
|
440 |
|
441 |
+
# Move tensors to CPU before generating heatmap to avoid dtype/device mismatches
|
442 |
+
img_cpu = img.to(self.device) if isinstance(img, torch.Tensor) else img
|
443 |
+
imo_cpu = imo.to(self.device)
|
444 |
+
sound_cpu = sound.to(self.device) if (sound is not None and isinstance(sound, torch.Tensor)) else sound
|
445 |
+
|
446 |
# Save final heatmap after TTA steps
|
447 |
+
self.generate_heatmap(img_cpu, imo_cpu, txt, sound=sound_cpu, modality=modality)
|
448 |
|
449 |
|
450 |
def generate_query_embeds(self, img, imo, txt, sound=None, modality="image"):
|