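"""Train and sample from a small character-level transformer language model.

Reads `input.txt`, builds a character vocabulary, trains the Transformer
(unless --eval is given), checkpoints the weights to `transformer.pth`,
and prints a sampled continuation.

Example invocations (the filename `train.py` is assumed here):

    python train.py -v          # train from scratch with verbose output
    python train.py -c          # resume training from transformer.pth
    python train.py -e -c       # skip training and sample from the checkpoint
"""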
import argparse
import time

import torch
from torch.nn import functional as F

from attention_head import MultiHeadAttention, TransFormerBlock

torch.manual_seed(1337)

def get_batch(batch_size, dataset, block_size):
    """Sample a random batch of contiguous (input, target) chunks from `dataset`."""
    # Random starting offsets; targets are the inputs shifted one position right.
    starts = torch.randint(high=len(dataset) - (block_size + 1), size=(batch_size,))
    xb = torch.stack([dataset[i:i + block_size] for i in starts])
    yb = torch.stack([dataset[i + 1:i + block_size + 1] for i in starts])
    return xb, yb

@torch.no_grad()
def evaluate(model, batch_size, block_size, dataset):
    """Return the loss on a single random batch drawn from `dataset`."""
    xb, yb = get_batch(batch_size, dataset, block_size)
    _, loss = model(xb, yb)
    return loss.item()

def train(model, optimizer, batch_size, block_size, train_ds, val_ds, steps):
    """Run the optimization loop, reporting train/val loss every 1000 steps."""
    sum_loss = 0.0
    for step in range(1, steps + 1):
        xb, yb = get_batch(batch_size, train_ds, block_size)
        _, loss = model(xb, yb)
        sum_loss += loss.item()
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        if step % 1000 == 0:
            val_loss = evaluate(model, 30, block_size, val_ds)
            print(f"step {step} || train loss: {sum_loss / 1000:.4f}, val loss: {val_loss:.4f}")
            sum_loss = 0.0

class Transformer(torch.nn.Module):
    """Decoder-only transformer: token and positional embeddings, a stack of
    TransFormerBlocks, and a linear head projecting back to the vocabulary."""

    def __init__(self, vocab_size, n_tf=3, block_size=8, token_embed_dim=16,
                 head_size=16, n_heads=8) -> None:
        super().__init__()
        self.block_size = block_size
        self.token_embedding_table = torch.nn.Embedding(vocab_size, token_embed_dim)
        self.positional_embedding = torch.nn.Embedding(block_size, token_embed_dim)
        # TransFormerBlock arguments assumed to be (n_embed, block_size, head_size,
        # n_heads), matching the previously hard-coded values 16 and 8.
        self.tf_blocks = torch.nn.Sequential(
            *[TransFormerBlock(token_embed_dim, block_size, head_size, n_heads) for _ in range(n_tf)]
        )
        # Project from the embedding dimension (not a hard-coded 128) to vocab logits.
        self.lm_head = torch.nn.Linear(token_embed_dim, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        token_embed = self.token_embedding_table(idx)                                      # (B, T, C)
        positional_embed = self.positional_embedding(torch.arange(T, device=idx.device))   # (T, C)
        x = token_embed + positional_embed
        x = self.tf_blocks(x)
        logits = self.lm_head(x)                                                           # (B, T, vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            # Crop the context to the last block_size tokens.
            logits, _ = self(idx[:, -self.block_size:])
            # Keep only the logits for the final position and sample from them.
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

class BigramLanguageModel(torch.nn.Module):
    """Simpler baseline: a single multi-head self-attention layer over token and
    positional embeddings, followed by a linear head."""

    def __init__(self, vocab_size, block_size=8, token_embed_dim=16):
        super().__init__()
        self.token_embedding_table = torch.nn.Embedding(vocab_size, token_embed_dim)
        self.positional_embedding = torch.nn.Embedding(block_size, token_embed_dim)
        self.attention_head = MultiHeadAttention(n_embed=token_embed_dim,
                                                 timesteps=block_size,
                                                 head_size=token_embed_dim // 4,
                                                 n_heads=4)
        self.lm_head = torch.nn.Linear(token_embed_dim, vocab_size)
        self.block_size = block_size

    def forward(self, idx, targets=None):
        B, T = idx.shape

        token_embedding = self.token_embedding_table(idx)
        positional_embedding = self.positional_embedding(
            torch.arange(T, dtype=torch.long, device=idx.device))
        x = token_embedding + positional_embedding
        x = self.attention_head(x)
        logits = self.lm_head(x)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            # Crop the context, take the logits for the last position, and sample.
            logits, _ = self(idx[:, -self.block_size:])
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

def main():
    # Hyperparameters (vocab_size is derived from the data below).
    batch_size = 32
    block_size = 128
    n_embed = 128
    n_tf = 3
    n_heads = 8
    head_size = 16

    parser = argparse.ArgumentParser(
        description='Train a character-level transformer language model'
    )
    parser.add_argument('-c', '--cont', action='store_true',
                        help='continue training from transformer.pth')
    parser.add_argument('-e', '--eval', action='store_true',
                        help='skip training and only sample from the model')
    parser.add_argument('-v', '--verbose', action='store_true',
                        help='print the model, parameter count and timing')
    args = parser.parse_args()

    # Build the character-level vocabulary and encode the corpus.
    with open('input.txt') as f:
        text = f.read()
    characters = sorted(set(text))
    vocab_size = len(characters)
    decoder = dict(enumerate(characters))
    encoder = {v: k for k, v in decoder.items()}
    encode = lambda c: encoder[c]
    decode = lambda i: decoder[i]

    text_tensor = torch.tensor([encode(c) for c in text])
    split = int(len(text_tensor) * 0.8)
    train_tensor = text_tensor[:split]
    val_tensor = text_tensor[split:]

    model = Transformer(vocab_size=vocab_size, n_tf=n_tf, block_size=block_size,
                        token_embed_dim=n_embed, head_size=head_size, n_heads=n_heads)
    if args.verbose:
        print(model)
        num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print('parameters:', num_params)

    if args.cont:
        state_dict = torch.load('transformer.pth')
        model.load_state_dict(state_dict)
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-5)

    start = time.time()
    if not args.eval:
        try:
            train(model, optimizer, batch_size=batch_size, block_size=block_size,
                  train_ds=train_tensor, val_ds=val_tensor, steps=100000)
        except KeyboardInterrupt:
            # Save a checkpoint on Ctrl-C so the run can be resumed with --cont.
            torch.save(model.state_dict(), 'transformer.pth')
            return
        if args.verbose:
            print('training time:', time.time() - start)
        torch.save(model.state_dict(), 'transformer.pth')

    # Sample: seed with 32 copies of token 0, then drop the seed from the printed output.
    model.eval()
    context = torch.zeros(1, 32, dtype=torch.long)
    sample = model.generate(context, 1000)[0].tolist()[32:]
    print(''.join(decode(i) for i in sample))

if __name__ == '__main__':
    main()