ivanzhouyq committed
Commit b2c5167 · 1 Parent(s): 988edaf

Set more precise shapes for the attention weights and outputs


This PR sets a more precise shape for the attention's weights, biases, and outputs.

This is not an issue when the embedding size is a multiple of the number of senses. However, when the hyperparameters are chosen such that the embedding size is not a multiple of the number of senses, the number of elements in `encoded` differs before and after reshaping. I found this mismatch when testing a larger model with an embedding size of 1280 and 48 senses, where I received the following error:

```
- sense_weight_net.c_attn.bias: found shape torch.Size([2496]) in the checkpoint and torch.Size([2560]) in the model instantiated
- sense_weight_net.c_attn.weight: found shape torch.Size([2496, 1280]) in the checkpoint and torch.Size([2560, 1280]) in the model instantiated
```
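
For reference, the mismatch is plain integer-division arithmetic. A quick check, using the sizes quoted above:

```python
# Where the two sizes in the error message come from, given the
# configuration described above (embedding size 1280, 48 senses):
embed_dim, num_senses = 1280, 48

print(2 * embed_dim)                               # 2560: nn.Linear(embed_dim, 2*embed_dim)
print(2 * num_senses * (embed_dim // num_senses))  # 2496: 2 * 48 * 26, the reshape target
```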

This PR addresses this issue. Of course, a better solution would be to recommend or enforce that the embedding size be a full multiple of the number of senses, as sketched below.
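
For illustration, such a guard could look like the following. This is a hypothetical sketch, not part of this PR; the constructor signature is assumed from the diff below, and the check and its error message are my own:

```python
import torch.nn as nn

class BackpackWeightNetwork(nn.Module):
    def __init__(self, num_senses, embed_dim):
        super().__init__()
        # Hypothetical guard, not part of this PR: reject embedding sizes
        # that are not a full multiple of the number of senses.
        if embed_dim % num_senses != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be a multiple of num_senses ({num_senses})"
            )
        self.n_embd = embed_dim
        self.num_senses = num_senses
        self.c_attn = nn.Linear(embed_dim, 2 * embed_dim)
        self.softmax_scale = None
```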

Files changed (1):
  1. modeling_backpack_gpt2.py (+3 −2)
modeling_backpack_gpt2.py:

```diff
@@ -101,13 +101,14 @@ class BackpackWeightNetwork(nn.Module):
         super().__init__()
         self.n_embd = embed_dim
         self.num_senses = num_senses
-        self.c_attn = nn.Linear(embed_dim, 2*embed_dim)
+        self.embed_per_sense = embed_dim // num_senses
+        self.c_attn = nn.Linear(embed_dim, 2 * num_senses * self.embed_per_sense)
         self.softmax_scale = None
 
     def forward(self, encoded):
         b, s, d = encoded.shape
         encoded = self.c_attn(encoded) # (b, s, 2*d)
-        encoded = encoded.reshape(b, s, 2, self.num_senses, d // self.num_senses) #(b, s, 2, nv, d//nv)
+        encoded = encoded.reshape(b, s, 2, self.num_senses, self.embed_per_sense) #(b, s, 2, nv, d//nv)
         batch_size, seqlen = encoded.shape[0], encoded.shape[1]
 
         # compute scores & mask
```
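
As a sanity check, here is a minimal standalone reproduction of the fixed projection and reshape with the problematic configuration from the description. The variable names follow the diff; the snippet itself is not part of the PR:

```python
import torch
import torch.nn as nn

b, s = 2, 16                               # arbitrary batch and sequence sizes
embed_dim, num_senses = 1280, 48
embed_per_sense = embed_dim // num_senses  # 26

# The fixed projection sizes its output to match the reshape exactly,
# even though embed_dim is not a multiple of num_senses.
c_attn = nn.Linear(embed_dim, 2 * num_senses * embed_per_sense)
encoded = c_attn(torch.randn(b, s, embed_dim))
encoded = encoded.reshape(b, s, 2, num_senses, embed_per_sense)
print(encoded.shape)  # torch.Size([2, 16, 2, 48, 26])
```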