remove cuda
Browse files- generic_utils.py +5 -5
generic_utils.py
CHANGED
@@ -24,10 +24,10 @@ def show_cam_on_image(img, mask):
|
|
24 |
|
25 |
|
26 |
# initialize ViT pretrained
|
27 |
-
model = vit_LRP(pretrained=True)
|
28 |
model.eval()
|
29 |
attribution_generator = LRP(model)
|
30 |
-
model_baseline = vit(pretrained=True)
|
31 |
model_baseline.eval()
|
32 |
baselines_generator = Baselines(model_baseline)
|
33 |
|
@@ -37,16 +37,16 @@ def generate_visualization(
|
|
37 |
):
|
38 |
if LRP:
|
39 |
transformer_attribution = attribution_generator.generate_LRP(
|
40 |
-
original_image.unsqueeze(0)
|
41 |
).detach()
|
42 |
else:
|
43 |
if method == "gradcam":
|
44 |
transformer_attribution = baselines_generator.generate_cam_attn(
|
45 |
-
original_image.unsqueeze(0)
|
46 |
).detach()
|
47 |
else:
|
48 |
transformer_attribution = baselines_generator.generate_rollout(
|
49 |
-
original_image.unsqueeze(0)
|
50 |
).detach()
|
51 |
if method != "full":
|
52 |
transformer_attribution = transformer_attribution.reshape(1, 1, 14, 14)
|
|
|
24 |
|
25 |
|
26 |
# initialize ViT pretrained
|
27 |
+
model = vit_LRP(pretrained=True)
|
28 |
model.eval()
|
29 |
attribution_generator = LRP(model)
|
30 |
+
model_baseline = vit(pretrained=True)
|
31 |
model_baseline.eval()
|
32 |
baselines_generator = Baselines(model_baseline)
|
33 |
|
|
|
37 |
):
|
38 |
if LRP:
|
39 |
transformer_attribution = attribution_generator.generate_LRP(
|
40 |
+
original_image.unsqueeze(0), method=method, index=class_index
|
41 |
).detach()
|
42 |
else:
|
43 |
if method == "gradcam":
|
44 |
transformer_attribution = baselines_generator.generate_cam_attn(
|
45 |
+
original_image.unsqueeze(0), index=class_index
|
46 |
).detach()
|
47 |
else:
|
48 |
transformer_attribution = baselines_generator.generate_rollout(
|
49 |
+
original_image.unsqueeze(0)
|
50 |
).detach()
|
51 |
if method != "full":
|
52 |
transformer_attribution = transformer_attribution.reshape(1, 1, 14, 14)
|