# File size: 1,372 Bytes
# Revision: 10a0e43
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import torch.nn as nn
class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding

    Projects a batch of images to a sequence of patch tokens using a single
    2D convolution whose kernel size and stride are both ``patch_size``, so
    the patches do not overlap.

    Args:
        patch_size: kernel size and stride of the projection convolution.
        in_chans: number of channels in the input images.
        embed_dim: number of output channels of the projection, i.e. the
            size of each patch token.
    """

    def __init__(
        self,
        patch_size=16,
        in_chans=3,
        embed_dim=768
    ):
        super().__init__()
        self.patch_size = patch_size
        # kernel_size == stride: every patch is projected exactly once.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """
        Embed a batch of images as patch tokens.

        Args:
            x: tensor with four dimensions, unpacked as ``(B, C, H, W)``.

        Returns:
            Tensor of shape ``(B, N, embed_dim)``, where ``N`` is the number
            of spatial positions produced by the projection convolution.

        Raises:
            ValueError: if ``x`` does not have exactly four dimensions.
        """
        # Explicit rank check (previously implied by unpacking x.shape into
        # four names); the exception type is unchanged.
        if x.ndim != 4:
            raise ValueError(
                f"PatchEmbed expects a 4D input (B, C, H, W), got shape {tuple(x.shape)}"
            )
        # (B, D, H', W') -> (B, D, H'*W') -> (B, H'*W', D)
        return self.proj(x).flatten(2).transpose(1, 2)
class PatchEmbed3D(nn.Module):
    """
    Video to Patch (Tubelet) Embedding

    Projects a batch of clips to a sequence of tokens using a single 3D
    convolution. Kernel size and stride are both
    ``(tubelet_size, patch_size, patch_size)``, so each token covers one
    non-overlapping block of ``tubelet_size`` frames by ``patch_size`` x
    ``patch_size`` pixels.

    Args:
        patch_size: spatial kernel size and stride (height and width).
        tubelet_size: temporal kernel size and stride (number of frames).
        in_chans: number of channels in the input clips.
        embed_dim: number of output channels of the projection, i.e. the
            size of each token.
    """

    def __init__(
        self,
        patch_size=16,
        tubelet_size=2,
        in_chans=3,
        embed_dim=768,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.tubelet_size = tubelet_size
        # kernel_size == stride: every tubelet is projected exactly once.
        self.proj = nn.Conv3d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=(tubelet_size, patch_size, patch_size),
            stride=(tubelet_size, patch_size, patch_size),
        )

    def forward(self, x, **kwargs):
        """
        Embed a batch of clips as tubelet tokens.

        Args:
            x: tensor with five dimensions, unpacked as ``(B, C, T, H, W)``.
            **kwargs: accepted and ignored by this method.

        Returns:
            Tensor of shape ``(B, N, embed_dim)``, where ``N`` is the number
            of spatio-temporal positions produced by the projection
            convolution.

        Raises:
            ValueError: if ``x`` does not have exactly five dimensions.
        """
        # Explicit rank check (previously implied by unpacking x.shape into
        # five names); the exception type is unchanged.
        if x.ndim != 5:
            raise ValueError(
                f"PatchEmbed3D expects a 5D input (B, C, T, H, W), got shape {tuple(x.shape)}"
            )
        # (B, D, T', H', W') -> (B, D, T'*H'*W') -> (B, T'*H'*W', D)
        return self.proj(x).flatten(2).transpose(1, 2)
|