import argparse
import time

import torch
from torch.nn import functional as F

from attention_head import AttentionHead, Head, MultiHeadAttention, TransFormerBlock

torch.manual_seed(1337)
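
# Trains a small character-level Transformer on input.txt and prints a sample from it.
# BigramLanguageModel further down is an earlier single-attention-block variant that
# main() does not currently use.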

def get_batch(batch_size, dataset, block_size):
    # Pick batch_size random starting offsets and build (batch_size, block_size) input/target
    # pairs, where the targets are the inputs shifted one position to the right.
    starts = torch.randint(high=len(dataset) - (block_size + 1), size=(batch_size,))
    xb = torch.zeros(batch_size, block_size, dtype=torch.long)
    yb = torch.zeros(batch_size, block_size, dtype=torch.long)
    for idx, start in enumerate(starts.tolist()):
        xb[idx, :] = dataset[start:start + block_size]
        yb[idx, :] = dataset[start + 1:start + block_size + 1]
    return xb, yb

@torch.no_grad()
def evaluate(model, batch_size, block_size, dataset):
    # Loss on a single held-out batch, computed without tracking gradients.
    xb, yb = get_batch(batch_size, dataset, block_size)
    logits, loss = model(xb, yb)
    return loss.item()

def train(model, optimizer, batch_size, block_size, train_ds, val_ds, steps):
    # Standard loop: sample a batch, compute the loss, backprop, step the optimizer.
    # Every 1000 steps, report the mean train loss since the last report and one val batch.
    sumloss = 0
    for step in range(1, steps + 1):
        xb, yb = get_batch(batch_size, train_ds, block_size)
        logits, loss = model(xb, yb)
        sumloss += loss.item()
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        if step % 1000 == 0:
            val_loss = evaluate(model, 30, block_size, val_ds)
            print(f"step {step} || train loss: {sumloss / 1000:.4f}, val loss: {val_loss:.4f}")
            sumloss = 0

class Transformer(torch.nn.Module):
    def __init__(self, vocab_size, n_tf=3, block_size=8, token_embed_dim=16, n_heads=8) -> None:
        super().__init__()
        self.block_size = block_size
        self.token_embedding_table = torch.nn.Embedding(vocab_size, token_embed_dim)
        self.positional_embedding = torch.nn.Embedding(block_size, token_embed_dim)
        # TransFormerBlock args: (n_embed, block_size, head_size, n_heads).
        # head_size = token_embed_dim // n_heads so the concatenated heads come back
        # out at the embedding width.
        self.tf_blocks = torch.nn.Sequential(
            *[TransFormerBlock(token_embed_dim, block_size, token_embed_dim // n_heads, n_heads)
              for _ in range(n_tf)]
        )
        # project from the embedding width back to vocabulary logits
        self.lm_head = torch.nn.Linear(token_embed_dim, vocab_size)
    def forward(self, idx, targets=None):
        B, T = idx.shape
        token_embed = self.token_embedding_table(idx)  # (B, T, C)
        positional_embed = self.positional_embedding(torch.arange(T, device=idx.device))  # (T, C)
        x = token_embed + positional_embed
        x = self.tf_blocks(x)
        logits = self.lm_head(x)  # (B, T, vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss
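    # Autoregressive sampling: crop the context to the last block_size tokens, take the
    # logits for the final position, sample the next token, and append it to the sequence.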
    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # get the predictions
            logits, loss = self(idx[:, -self.block_size:])
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

class BigramLanguageModel(torch.nn.Module):
    def __init__(self, vocab_size, block_size=8, token_embed_dim=16):
        super().__init__()
        self.token_embedding_table = torch.nn.Embedding(vocab_size, token_embed_dim)
        self.positional_embedding = torch.nn.Embedding(block_size, token_embed_dim)
        self.attention_head = MultiHeadAttention(
            n_embed=token_embed_dim,
            timesteps=block_size,
            # head_size * n_heads must equal token_embed_dim: the concatenated head
            # outputs have to match the embedding width again.
            head_size=token_embed_dim // 4,
            n_heads=4)  # (B, T, C) in, (B, T, C) out
        self.lm_head = torch.nn.Linear(token_embed_dim, vocab_size)  # linear map over the channel dim
        self.block_size = block_size
    def forward(self, idx, targets=None):
        B, T = idx.shape
        # idx and targets are both (B, T) tensors of integers
        token_embedding = self.token_embedding_table(idx)  # (B, T, embed_dim)
        positional_embedding = self.positional_embedding(torch.arange(T, dtype=torch.long, device=idx.device))  # (T, embed_dim)
        x = token_embedding + positional_embedding  # (B, T, embed_dim)
        x = self.attention_head(x)  # (B, T, embed_dim)
        logits = self.lm_head(x)  # (B, T, vocab_size)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # get the predictions
            logits, loss = self(idx[:, -self.block_size:])
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

def main():
    ########################
    # PARAMS ###############
    batch_size = 32
    block_size = 128
    n_embed = 128
    n_tf = 3
    n_heads = 8
    head_size = 16   # = n_embed // n_heads; the model derives this itself
    vocab_size = 65  # must match len(set(text)) for input.txt; 65 fits the tiny-shakespeare dataset
    ########################
    parser = argparse.ArgumentParser(
        description='Train a character-level Transformer language model'
    )
    parser.add_argument('-c', '--cont', action='store_true', help='resume from transformer.pth')
    parser.add_argument('-e', '--eval', action='store_true', help='skip training and only generate')
    parser.add_argument('-v', '--verbose', action='store_true', help='print model summary and timing')
    args = parser.parse_args()

    # build a character-level vocabulary and encode the whole corpus as a 1-D tensor
    with open('input.txt') as f:
        text = f.read()
    characters = sorted(set(text))
    decoder = dict(enumerate(characters))
    encoder = {v: k for k, v in decoder.items()}
    encode = lambda c: encoder[c]  # character -> integer id
    decode = lambda i: decoder[i]  # integer id -> character
    text_tensor = torch.tensor([encode(c) for c in text])
    train_tensor = text_tensor[:int(len(text_tensor) * 0.8)]
    val_tensor = text_tensor[int(len(text_tensor) * 0.8):]

    model = Transformer(vocab_size=vocab_size, n_tf=n_tf, block_size=block_size,
                        token_embed_dim=n_embed, n_heads=n_heads)
    if args.verbose:
        print(model)
        num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print('parameters:', num_params)
    # if -c is passed, resume from the saved checkpoint
    if args.cont:
        state_dict = torch.load('transformer.pth')
        model.load_state_dict(state_dict)
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-5)
    s = time.time()
    if not args.eval:
        try:
            train(model, optimizer, batch_size=batch_size, block_size=block_size,
                  train_ds=train_tensor, val_ds=val_tensor, steps=100000)
        except KeyboardInterrupt:
            # save progress before exiting on Ctrl-C
            torch.save(model.state_dict(), 'transformer.pth')
            exit()
    if args.verbose:
        print('training time: ', time.time() - s)
    torch.save(model.state_dict(), 'transformer.pth')
    model.eval()
    # prompt with 32 copies of token 0, then strip the prompt before printing the sample
    print(''.join(decode(c) for c in model.generate(torch.zeros(1, 32, dtype=torch.long), 1000)[0].tolist()[32:]))
    # 2.57 adam
if __name__ == '__main__':
    main()