# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from torch.nn.init import trunc_normal_
from torch.nn.utils import weight_norm

class DINOHead(nn.Module):
    """DINO projection head: an MLP bottleneck followed by a weight-normalized prototype layer."""

    def __init__(
        self,
        in_dim,
        out_dim,
        use_bn=False,
        nlayers=3,
        hidden_dim=2048,
        bottleneck_dim=256,
        mlp_bias=True,
    ):
        super().__init__()
        nlayers = max(nlayers, 1)
        self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
        self.apply(self._init_weights)
        # Weight-normalized final projection onto the prototypes; the scale
        # (weight_g) is initialized to 1 so the layer starts purely directional.
        self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.last_layer.weight_g.data.fill_(1)

    def _init_weights(self, m):
        # Truncated-normal init for linear weights; zero the biases.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.mlp(x)
        # fp16 lacks the precision for the default eps; use a larger one so
        # normalize() does not divide by a norm that rounds to zero.
        eps = 1e-6 if x.dtype == torch.float16 else 1e-12
        x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
        x = self.last_layer(x)
        return x


def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
    if nlayers == 1:
        return nn.Linear(in_dim, bottleneck_dim, bias=bias)
    else:
        layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
        if use_bn:
            layers.append(nn.BatchNorm1d(hidden_dim))
        layers.append(nn.GELU())
        for _ in range(nlayers - 2):
            layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
            if use_bn:
                layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.GELU())
        layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
        return nn.Sequential(*layers)
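

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal example of driving DINOHead end to end. The dimensions below
# (in_dim=384 as for a ViT-S embedding, out_dim=4096 prototypes, batch of 8)
# are assumptions chosen for illustration, not values from this repository.
if __name__ == "__main__":
    head = DINOHead(in_dim=384, out_dim=4096)
    x = torch.randn(8, 384)  # fake batch of backbone embeddings
    logits = head(x)         # MLP -> L2-normalize -> prototype scores
    print(logits.shape)      # torch.Size([8, 4096])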