Update modeling_intern_vit.py
Browse files- modeling_intern_vit.py +9 -2
modeling_intern_vit.py
CHANGED
|
@@ -129,6 +129,12 @@ except Exception:
|
|
| 129 |
pass
|
| 130 |
|
| 131 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
class InternVisionEmbeddings(nn.Module):
|
| 133 |
def __init__(self, config: InternVisionConfig):
|
| 134 |
super().__init__()
|
|
@@ -267,11 +273,12 @@ class InternVisionEncoderLayer(nn.Module):
|
|
| 267 |
super().__init__()
|
| 268 |
self.embed_dim = config.hidden_size
|
| 269 |
self.intermediate_size = config.intermediate_size
|
|
|
|
| 270 |
|
| 271 |
self.attn = InternAttention(config)
|
| 272 |
self.mlp = InternMLP(config)
|
| 273 |
-
self.norm1 =
|
| 274 |
-
self.norm2 =
|
| 275 |
|
| 276 |
self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
|
| 277 |
self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
|
|
|
|
| 129 |
pass
|
| 130 |
|
| 131 |
|
| 132 |
+
NORM2FN = {
|
| 133 |
+
'rms_norm': InternRMSNorm,
|
| 134 |
+
'layer_norm': nn.LayerNorm,
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
|
| 138 |
class InternVisionEmbeddings(nn.Module):
|
| 139 |
def __init__(self, config: InternVisionConfig):
|
| 140 |
super().__init__()
|
|
|
|
| 273 |
super().__init__()
|
| 274 |
self.embed_dim = config.hidden_size
|
| 275 |
self.intermediate_size = config.intermediate_size
|
| 276 |
+
self.norm_type = config.norm_type
|
| 277 |
|
| 278 |
self.attn = InternAttention(config)
|
| 279 |
self.mlp = InternMLP(config)
|
| 280 |
+
self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
|
| 281 |
+
self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
|
| 282 |
|
| 283 |
self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
|
| 284 |
self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
|