Commit
·
952897b
1
Parent(s):
b845577
feat: add autocasting in vision.patch_embed
Browse files- eva_model.py +2 -1
eva_model.py
CHANGED
|
@@ -462,13 +462,14 @@ class PatchEmbed(nn.Module):
|
|
| 462 |
)
|
| 463 |
|
| 464 |
def forward(self, x, **kwargs):
|
|
|
|
| 465 |
B, C, H, W = x.shape
|
| 466 |
# FIXME look at relaxing size constraints
|
| 467 |
assert H == self.img_size[0] and W == self.img_size[1], (
|
| 468 |
f"Input image size ({H}*{W}) doesn't match model "
|
| 469 |
f'({self.img_size[0]}*{self.img_size[1]}).'
|
| 470 |
)
|
| 471 |
-
x = self.proj(x).flatten(2).transpose(1, 2)
|
| 472 |
return x
|
| 473 |
|
| 474 |
|
|
|
|
| 462 |
)
|
| 463 |
|
| 464 |
def forward(self, x, **kwargs):
|
| 465 |
+
target_dtype = self.proj.weight.dtype
|
| 466 |
B, C, H, W = x.shape
|
| 467 |
# FIXME look at relaxing size constraints
|
| 468 |
assert H == self.img_size[0] and W == self.img_size[1], (
|
| 469 |
f"Input image size ({H}*{W}) doesn't match model "
|
| 470 |
f'({self.img_size[0]}*{self.img_size[1]}).'
|
| 471 |
)
|
| 472 |
+
x = self.proj(x.to(dtype=target_dtype)).flatten(2).transpose(1, 2)
|
| 473 |
return x
|
| 474 |
|
| 475 |
|