AustingDong
committed on
Commit
·
e788822
1
Parent(s):
8235fd2
finished baseline
Browse files
- demo/cam.py +6 -1
- demo/model_utils.py +1 -1
demo/cam.py
CHANGED
|
@@ -535,7 +535,11 @@ class AttentionGuidedCAMChartGemma(AttentionGuidedCAM):
|
|
| 535 |
elif focus == "Language Model":
|
| 536 |
self.model.zero_grad()
|
| 537 |
# print(outputs_raw)
|
| 538 |
-
loss = outputs_raw.logits.max(dim=-1).values.sum()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 539 |
loss.backward()
|
| 540 |
|
| 541 |
|
|
@@ -556,6 +560,7 @@ class AttentionGuidedCAMChartGemma(AttentionGuidedCAM):
|
|
| 556 |
|
| 557 |
grad = F.relu(grad)
|
| 558 |
|
|
|
|
| 559 |
cam = act * grad # shape: [1, heads, seq_len, seq_len]
|
| 560 |
cam = cam.sum(dim=1) # shape: [1, seq_len, seq_len]
|
| 561 |
cam = cam.to(torch.float32).detach().cpu()
|
|
|
|
| 535 |
elif focus == "Language Model":
|
| 536 |
self.model.zero_grad()
|
| 537 |
# print(outputs_raw)
|
| 538 |
+
# loss = outputs_raw.logits.max(dim=-1).values.sum()
|
| 539 |
+
if class_idx == -1:
|
| 540 |
+
loss = outputs_raw.logits.max(dim=-1).values.sum()
|
| 541 |
+
else:
|
| 542 |
+
loss = outputs_raw.logits.max(dim=-1).values[0, start_idx + class_idx]
|
| 543 |
loss.backward()
|
| 544 |
|
| 545 |
|
|
|
|
| 560 |
|
| 561 |
grad = F.relu(grad)
|
| 562 |
|
| 563 |
+
# cam = grad
|
| 564 |
cam = act * grad # shape: [1, heads, seq_len, seq_len]
|
| 565 |
cam = cam.sum(dim=1) # shape: [1, seq_len, seq_len]
|
| 566 |
cam = cam.to(torch.float32).detach().cpu()
|
demo/model_utils.py
CHANGED
|
@@ -204,7 +204,7 @@ class ChartGemma_Utils(Model_Utils):
|
|
| 204 |
self.vl_gpt = PaliGemmaForConditionalGeneration.from_pretrained(
|
| 205 |
model_path,
|
| 206 |
torch_dtype=torch.float16,
|
| 207 |
-
attn_implementation="
|
| 208 |
output_attentions=True
|
| 209 |
)
|
| 210 |
self.vl_gpt, self.dtype, self.cuda_device = set_dtype_device(self.vl_gpt)
|
|
|
|
| 204 |
self.vl_gpt = PaliGemmaForConditionalGeneration.from_pretrained(
|
| 205 |
model_path,
|
| 206 |
torch_dtype=torch.float16,
|
| 207 |
+
attn_implementation="sdpa",
|
| 208 |
output_attentions=True
|
| 209 |
)
|
| 210 |
self.vl_gpt, self.dtype, self.cuda_device = set_dtype_device(self.vl_gpt)
|