Tousifahamed committed
Commit 0f122cf · verified · 1 Parent(s): 85c27c2

Upload 2 files

Files changed (1): model_utils.py (+13 -42)
model_utils.py CHANGED

@@ -17,53 +17,24 @@ class Block(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.ln_1 = nn.LayerNorm(config.n_embd)
-        self.attn = MultiHeadAttention(config)
+        self.attn = nn.MultiheadAttention(config.n_embd, config.n_head, dropout=config.dropout, batch_first=True)
         self.ln_2 = nn.LayerNorm(config.n_embd)
-        self.mlp = FeedForward(config)
+        self.mlp = nn.Sequential(
+            nn.Linear(config.n_embd, 4 * config.n_embd),
+            nn.GELU(),
+            nn.Dropout(config.dropout),
+            nn.Linear(4 * config.n_embd, config.n_embd),
+            nn.Dropout(config.dropout),
+        )
 
     def forward(self, x):
-        x = x + self.attn(self.ln_1(x))
+        x = x + self._attention_block(self.ln_1(x))
         x = x + self.mlp(self.ln_2(x))
         return x
-
-class MultiHeadAttention(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        self.n_head = config.n_head
-        self.n_embd = config.n_embd
-        assert self.n_embd % self.n_head == 0
-
-        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
-        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
-        self.dropout = nn.Dropout(config.dropout)
-
-    def forward(self, x):
-        B, T, C = x.size()
-        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
-        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
-        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
-        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
-
-        att = (q @ k.transpose(-2, -1)) * (1.0 / torch.sqrt(torch.tensor(k.size(-1))))
-        att = F.softmax(att, dim=-1)
-        att = self.dropout(att)
-        y = att @ v
-        y = y.transpose(1, 2).contiguous().view(B, T, C)
-        return self.c_proj(y)
-
-class FeedForward(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
-        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
-        self.dropout = nn.Dropout(config.dropout)
-
-    def forward(self, x):
-        x = F.gelu(self.c_fc(x))
-        x = self.dropout(x)
-        x = self.c_proj(x)
-        x = self.dropout(x)
-        return x
+
+    def _attention_block(self, x):
+        attn_output, _ = self.attn(x, x, x)
+        return attn_output
 
 class GPT(nn.Module):
     def __init__(self, config):
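
For reference, a minimal smoke test of the refactored Block might look like the sketch below. It is not part of this commit: the Config dataclass, its default values, and the `from model_utils import Block` import path are assumptions for illustration; the repo's real config object only needs to expose n_embd, n_head, and dropout.

# Hedged sketch (not part of this commit): exercise the refactored Block
# with a dummy config and check that the residual block preserves shape.
from dataclasses import dataclass

import torch

from model_utils import Block  # assumes model_utils.py is importable


@dataclass
class Config:            # hypothetical stand-in for the repo's config object
    n_embd: int = 64     # embedding width, must be divisible by n_head
    n_head: int = 4      # number of heads for nn.MultiheadAttention
    dropout: float = 0.1


config = Config()
block = Block(config).eval()  # eval() disables the Dropout layers

x = torch.randn(2, 10, config.n_embd)  # (batch, seq_len, n_embd); batch_first=True
with torch.no_grad():
    y = block(x)

assert y.shape == x.shape  # the pre-norm residual block is shape-preserving

As in the removed hand-written MultiHeadAttention, self.attn(x, x, x) is called without an attn_mask, so the attention in this block is bidirectional unless a causal mask is applied elsewhere in the model.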