Set more precise shape to the attention weights and outputs
Browse filesThis PR sets more precise shape to the attention's weights, biases, and the outputs.
This was not an issue when the embedding size is a multiple of senses. However, when the hyperparamters are picked so that the embedding size is no longer the multiple of the number of senses, it would cause different number of parameters in `encoded` before and after reshaping. I found this mismatching when testing with a larger model with embedding size = 1280 and number of senses at 48. I received the following error:
```
- sense_weight_net.c_attn.bias: found shape torch.Size([2496]) in the checkpoint and torch.Size([2560]) in the model instantiated
- sense_weight_net.c_attn.weight: found shape torch.Size([2496, 1280]) in the checkpoint and torch.Size([2560, 1280]) in the model instantiated
```
This PR addresses this issue. Of course, a better solution should be recommending/enforcing the embedding size to be a full multiple of senses.
@@ -101,13 +101,14 @@ class BackpackWeightNetwork(nn.Module):
|
|
101 |
super().__init__()
|
102 |
self.n_embd = embed_dim
|
103 |
self.num_senses = num_senses
|
104 |
-
self.
|
|
|
105 |
self.softmax_scale = None
|
106 |
|
107 |
def forward(self, encoded):
|
108 |
b, s, d = encoded.shape
|
109 |
encoded = self.c_attn(encoded) # (b, s, 2*d)
|
110 |
-
encoded = encoded.reshape(b, s, 2, self.num_senses,
|
111 |
batch_size, seqlen = encoded.shape[0], encoded.shape[1]
|
112 |
|
113 |
# compute scores & mask
|
|
|
101 |
super().__init__()
|
102 |
self.n_embd = embed_dim
|
103 |
self.num_senses = num_senses
|
104 |
+
self.embed_per_sense = embed_dim // num_senses
|
105 |
+
self.c_attn = nn.Linear(embed_dim, 2 * num_senses * self.embed_per_sense)
|
106 |
self.softmax_scale = None
|
107 |
|
108 |
def forward(self, encoded):
|
109 |
b, s, d = encoded.shape
|
110 |
encoded = self.c_attn(encoded) # (b, s, 2*d)
|
111 |
+
encoded = encoded.reshape(b, s, 2, self.num_senses, self.embed_per_sense) #(b, s, 2, nv, d//nv)
|
112 |
batch_size, seqlen = encoded.shape[0], encoded.shape[1]
|
113 |
|
114 |
# compute scores & mask
|