|  | """ | 
					
						
						|  | taken from: https://github.com/karpathy/minGPT/ | 
					
						
						|  | GPT model: | 
					
						
						|  | - the initial stem consists of a combination of token encoding and a positional encoding | 
					
						
						|  | - the meat of it is a uniform sequence of Transformer blocks | 
					
						
- each Transformer block applies a self-attention block followed by a 1-hidden-layer MLP block
					
						
						|  | - all blocks feed into a central residual pathway similar to resnets | 
					
						
						|  | - the final decoder is a linear projection into a vanilla Softmax classifier | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | import math | 
					
						
						|  | import logging | 
					
						
						|  |  | 
					
						
						|  | import torch | 
					
						
						|  | import torch.nn as nn | 
					
						
						|  | from torch.nn import functional as F | 
					
						
						|  | from transformers import top_k_top_p_filtering | 
					
						
						|  |  | 
					
						
						|  | logger = logging.getLogger(__name__) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class GPTConfig: | 
					
						
						|  | """ base GPT config, params common to all GPT versions """ | 
					
						
						|  | embd_pdrop = 0.1 | 
					
						
						|  | resid_pdrop = 0.1 | 
					
						
						|  | attn_pdrop = 0.1 | 
					
						
						|  |  | 
					
						
						|  | def __init__(self, vocab_size, block_size, **kwargs): | 
					
						
						|  | self.vocab_size = vocab_size | 
					
						
						|  | self.block_size = block_size | 
					
						
						|  | for k,v in kwargs.items(): | 
					
						
						|  | setattr(self, k, v) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class GPT1Config(GPTConfig): | 
					
						
						|  | """ GPT-1 like network roughly 125M params """ | 
					
						
						|  | n_layer = 12 | 
					
						
						|  | n_head = 12 | 
					
						
						|  | n_embd = 768 | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class CausalSelfAttention(nn.Module): | 
					
						
						|  | """ | 
					
						
						|  | A vanilla multi-head masked self-attention layer with a projection at the end. | 
					
						
						|  | It is possible to use torch.nn.MultiheadAttention here but I am including an | 
					
						
						|  | explicit implementation here to show that there is nothing too scary here. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | def __init__(self, config): | 
					
						
						|  | super().__init__() | 
					
						
						|  | assert config.n_embd % config.n_head == 0 | 
					
						
						|  |  | 
					
						
						|  | self.key = nn.Linear(config.n_embd, config.n_embd) | 
					
						
						|  | self.query = nn.Linear(config.n_embd, config.n_embd) | 
					
						
						|  | self.value = nn.Linear(config.n_embd, config.n_embd) | 
					
						
						|  |  | 
					
						
						|  | self.attn_drop = nn.Dropout(config.attn_pdrop) | 
					
						
						|  | self.resid_drop = nn.Dropout(config.resid_pdrop) | 
					
						
						|  |  | 
					
						
						|  | self.proj = nn.Linear(config.n_embd, config.n_embd) | 
					
						
						|  |  | 
					
						
						|  | mask = torch.tril(torch.ones(config.block_size, | 
					
						
						|  | config.block_size)) | 
					
						
						|  | if hasattr(config, "n_unmasked"): | 
					
						
						|  | mask[:config.n_unmasked, :config.n_unmasked] = 1 | 
					
						
						|  | self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size)) | 
					
						
						|  | self.n_head = config.n_head | 
					
						
						|  |  | 
					
						
						|  | def forward(self, x, layer_past=None): | 
					
						
						|  | B, T, C = x.size() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) | 
					
						
						|  | q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) | 
					
						
						|  | v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) | 
					
						
						|  |  | 
					
						
						|  | present = torch.stack((k, v)) | 
					
						
						|  | if layer_past is not None: | 
					
						
						|  | past_key, past_value = layer_past | 
					
						
						|  | k = torch.cat((past_key, k), dim=-2) | 
					
						
						|  | v = torch.cat((past_value, v), dim=-2) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) | 
					
						
						|  | if layer_past is None: | 
					
						
						|  | att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf')) | 
					
						
						|  |  | 
					
						
						|  | att = F.softmax(att, dim=-1) | 
					
						
						|  | att = self.attn_drop(att) | 
					
						
						|  | y = att @ v | 
					
						
						|  | y = y.transpose(1, 2).contiguous().view(B, T, C) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | y = self.resid_drop(self.proj(y)) | 
					
						
						|  | return y, present | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class Block(nn.Module): | 
					
						
						|  | """ an unassuming Transformer block """ | 
					
						
						|  | def __init__(self, config): | 
					
						
						|  | super().__init__() | 
					
						
						|  | self.ln1 = nn.LayerNorm(config.n_embd) | 
					
						
						|  | self.ln2 = nn.LayerNorm(config.n_embd) | 
					
						
						|  | self.attn = CausalSelfAttention(config) | 
					
						
						|  | self.mlp = nn.Sequential( | 
					
						
						|  | nn.Linear(config.n_embd, 4 * config.n_embd), | 
					
						
						|  | nn.GELU(), | 
					
						
						|  | nn.Linear(4 * config.n_embd, config.n_embd), | 
					
						
						|  | nn.Dropout(config.resid_pdrop), | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | def forward(self, x, layer_past=None, return_present=False): | 
					
						
						|  |  | 
					
						
						|  | if return_present: assert not self.training | 
					
						
						|  |  | 
					
						
						|  | attn, present = self.attn(self.ln1(x), layer_past=layer_past) | 
					
						
						|  |  | 
					
						
						|  | x = x + attn | 
					
						
						|  | x = x + self.mlp(self.ln2(x)) | 
					
						
						|  | if layer_past is not None or return_present: | 
					
						
						|  | return x, present | 
					
						
						|  | return x | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class GPT(nn.Module): | 
					
						
						|  | """  the full GPT language model, with a context size of block_size """ | 
					
						
						|  | def __init__(self, vocab_size, block_size, n_layer=12, n_head=8, n_embd=256, | 
					
						
						|  | embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0): | 
					
						
						|  | super().__init__() | 
					
						
						|  | config = GPTConfig(vocab_size=vocab_size, block_size=block_size, | 
					
						
						|  | embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop, | 
					
						
						|  | n_layer=n_layer, n_head=n_head, n_embd=n_embd, | 
					
						
						|  | n_unmasked=n_unmasked) | 
					
						
						|  |  | 
					
						
						|  | self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd) | 
					
						
						|  | self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd)) | 
					
						
						|  | self.drop = nn.Dropout(config.embd_pdrop) | 
					
						
						|  |  | 
					
						
						|  | self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)]) | 
					
						
						|  |  | 
					
						
						|  | self.ln_f = nn.LayerNorm(config.n_embd) | 
					
						
						|  | self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False) | 
					
						
						|  | self.block_size = config.block_size | 
					
						
						|  | self.apply(self._init_weights) | 
					
						
						|  | self.config = config | 
					
						
						|  | logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters())) | 
					
						
						|  |  | 
					
						
						|  | def get_block_size(self): | 
					
						
						|  | return self.block_size | 
					
						
						|  |  | 
					
						
						|  | def _init_weights(self, module): | 
					
						
						|  | if isinstance(module, (nn.Linear, nn.Embedding)): | 
					
						
						|  | module.weight.data.normal_(mean=0.0, std=0.02) | 
					
						
						|  | if isinstance(module, nn.Linear) and module.bias is not None: | 
					
						
						|  | module.bias.data.zero_() | 
					
						
						|  | elif isinstance(module, nn.LayerNorm): | 
					
						
						|  | module.bias.data.zero_() | 
					
						
						|  | module.weight.data.fill_(1.0) | 
					
						
						|  |  | 
					
						
						|  | def forward(self, idx, embeddings=None, targets=None): | 
					
						
						|  |  | 
					
						
						|  | token_embeddings = self.tok_emb(idx) | 
					
						
						|  |  | 
					
						
						|  | if embeddings is not None: | 
					
						
						|  | token_embeddings = torch.cat((embeddings, token_embeddings), dim=1) | 
					
						
						|  |  | 
					
						
						|  | t = token_embeddings.shape[1] | 
					
						
						|  | assert t <= self.block_size, "Cannot forward, model block size is exhausted." | 
					
						
						|  | position_embeddings = self.pos_emb[:, :t, :] | 
					
						
						|  | x = self.drop(token_embeddings + position_embeddings) | 
					
						
						|  | x = self.blocks(x) | 
					
						
						|  | x = self.ln_f(x) | 
					
						
						|  | logits = self.head(x) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | loss = None | 
					
						
						|  | if targets is not None: | 
					
						
						|  | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) | 
					
						
						|  |  | 
					
						
						|  | return logits, loss | 
					
						
						|  |  | 
					
						
						|  | def forward_with_past(self, idx, embeddings=None, targets=None, past=None, past_length=None): | 
					
						
						|  |  | 
					
						
						|  | assert not self.training | 
					
						
						|  | token_embeddings = self.tok_emb(idx) | 
					
						
						|  | if embeddings is not None: | 
					
						
						|  | token_embeddings = torch.cat((embeddings, token_embeddings), dim=1) | 
					
						
						|  |  | 
					
						
						|  | if past is not None: | 
					
						
						|  | assert past_length is not None | 
					
						
						|  | past = torch.cat(past, dim=-2) | 
					
						
						|  | past_shape = list(past.shape) | 
					
						
						|  | expected_shape = [self.config.n_layer, 2, idx.shape[0], self.config.n_head, past_length, self.config.n_embd//self.config.n_head] | 
					
						
						|  | assert past_shape == expected_shape, f"{past_shape} =/= {expected_shape}" | 
					
						
						|  | position_embeddings = self.pos_emb[:, past_length, :] | 
					
						
						|  | else: | 
					
						
						|  | position_embeddings = self.pos_emb[:, :token_embeddings.shape[1], :] | 
					
						
						|  |  | 
					
						
						|  | x = self.drop(token_embeddings + position_embeddings) | 
					
						
						|  | presents = [] | 
					
						
						|  | for i, block in enumerate(self.blocks): | 
					
						
						|  | x, present = block(x, layer_past=past[i, ...] if past is not None else None, return_present=True) | 
					
						
						|  | presents.append(present) | 
					
						
						|  |  | 
					
						
						|  | x = self.ln_f(x) | 
					
						
						|  | logits = self.head(x) | 
					
						
						|  |  | 
					
						
						|  | loss = None | 
					
						
						|  | if targets is not None: | 
					
						
						|  | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) | 
					
						
						|  |  | 
					
						
						|  | return logits, loss, torch.stack(presents) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class DummyGPT(nn.Module): | 
					
						
						|  |  | 
					
						
						|  | def __init__(self, add_value=1): | 
					
						
						|  | super().__init__() | 
					
						
						|  | self.add_value = add_value | 
					
						
						|  |  | 
					
						
						|  | def forward(self, idx): | 
					
						
						|  | return idx + self.add_value, None | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class CodeGPT(nn.Module): | 
					
						
						|  | """Takes in semi-embeddings""" | 
					
						
						|  | def __init__(self, vocab_size, block_size, in_channels, n_layer=12, n_head=8, n_embd=256, | 
					
						
						|  | embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0): | 
					
						
						|  | super().__init__() | 
					
						
						|  | config = GPTConfig(vocab_size=vocab_size, block_size=block_size, | 
					
						
						|  | embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop, | 
					
						
						|  | n_layer=n_layer, n_head=n_head, n_embd=n_embd, | 
					
						
						|  | n_unmasked=n_unmasked) | 
					
						
						|  |  | 
					
						
						|  | self.tok_emb = nn.Linear(in_channels, config.n_embd) | 
					
						
						|  | self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd)) | 
					
						
						|  | self.drop = nn.Dropout(config.embd_pdrop) | 
					
						
						|  |  | 
					
						
						|  | self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)]) | 
					
						
						|  |  | 
					
						
						|  | self.ln_f = nn.LayerNorm(config.n_embd) | 
					
						
						|  | self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False) | 
					
						
						|  | self.block_size = config.block_size | 
					
						
						|  | self.apply(self._init_weights) | 
					
						
						|  | self.config = config | 
					
						
						|  | logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters())) | 
					
						
						|  |  | 
					
						
						|  | def get_block_size(self): | 
					
						
						|  | return self.block_size | 
					
						
						|  |  | 
					
						
						|  | def _init_weights(self, module): | 
					
						
						|  | if isinstance(module, (nn.Linear, nn.Embedding)): | 
					
						
						|  | module.weight.data.normal_(mean=0.0, std=0.02) | 
					
						
						|  | if isinstance(module, nn.Linear) and module.bias is not None: | 
					
						
						|  | module.bias.data.zero_() | 
					
						
						|  | elif isinstance(module, nn.LayerNorm): | 
					
						
						|  | module.bias.data.zero_() | 
					
						
						|  | module.weight.data.fill_(1.0) | 
					
						
						|  |  | 
					
						
						|  | def forward(self, idx, embeddings=None, targets=None): | 
					
						
						|  |  | 
					
						
						|  | token_embeddings = self.tok_emb(idx) | 
					
						
						|  |  | 
					
						
						|  | if embeddings is not None: | 
					
						
						|  | token_embeddings = torch.cat((embeddings, token_embeddings), dim=1) | 
					
						
						|  |  | 
					
						
						|  | t = token_embeddings.shape[1] | 
					
						
						|  | assert t <= self.block_size, "Cannot forward, model block size is exhausted." | 
					
						
						|  | position_embeddings = self.pos_emb[:, :t, :] | 
					
						
						|  | x = self.drop(token_embeddings + position_embeddings) | 
					
						
						|  | x = self.blocks(x) | 
					
						
						|  | x = self.taming_cinln_f(x) | 
					
						
						|  | logits = self.head(x) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | loss = None | 
					
						
						|  | if targets is not None: | 
					
						
						|  | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) | 
					
						
						|  |  | 
					
						
						|  | return logits, loss | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def top_k_logits(logits, k): | 
					
						
						|  | v, ix = torch.topk(logits, k) | 
					
						
						|  | out = logits.clone() | 
					
						
						|  | out[out < v[:, [-1]]] = -float('Inf') | 
					
						
						|  | return out | 
					
						
						|  |  | 
					
						
						|  | @torch.no_grad() | 
					
						
						|  | def sample(model, x, steps, temperature=1.0, sample=False, top_k=None): | 
					
						
						|  | """ | 
					
						
						|  | take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in | 
					
						
						|  | the sequence, feeding the predictions back into the model each time. Clearly the sampling | 
					
						
						|  | has quadratic complexity unlike an RNN that is only linear, and has a finite context window | 
					
						
						|  | of block_size, unlike an RNN that has an infinite context window. | 
					
						
						|  | """ | 
					
						
						|  | block_size = model.get_block_size() | 
					
						
						|  | model.eval() | 
					
						
						|  | for k in range(steps): | 
					
						
						|  | x_cond = x if x.size(1) <= block_size else x[:, -block_size:] | 
					
						
						|  | logits, _ = model(x_cond) | 
					
						
						|  |  | 
					
						
						|  | logits = logits[:, -1, :] / temperature | 
					
						
						|  |  | 
					
						
						|  | if top_k is not None: | 
					
						
						|  | logits = top_k_logits(logits, top_k) | 
					
						
						|  |  | 
					
						
						|  | probs = F.softmax(logits, dim=-1) | 
					
						
						|  |  | 
					
						
						|  | if sample: | 
					
						
						|  | ix = torch.multinomial(probs, num_samples=1) | 
					
						
						|  | else: | 
					
						
						|  | _, ix = torch.topk(probs, k=1, dim=-1) | 
					
						
						|  |  | 
					
						
						|  | x = torch.cat((x, ix), dim=1) | 
					
						
						|  |  | 
					
						
						|  | return x | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | @torch.no_grad() | 
					
						
						|  | def sample_with_past(x, model, steps, temperature=1., sample_logits=True, | 
					
						
						|  | top_k=None, top_p=None, callback=None): | 
					
						
						|  |  | 
					
						
						|  | sample = x | 
					
						
						|  | cond_len = x.shape[1] | 
					
						
						|  | past = None | 
					
						
						|  | for n in range(steps): | 
					
						
						|  | if callback is not None: | 
					
						
						|  | callback(n) | 
					
						
						|  | logits, _, present = model.forward_with_past(x, past=past, past_length=(n+cond_len-1)) | 
					
						
						|  | if past is None: | 
					
						
						|  | past = [present] | 
					
						
						|  | else: | 
					
						
						|  | past.append(present) | 
					
						
						|  | logits = logits[:, -1, :] / temperature | 
					
						
						|  | if top_k is not None: | 
					
						
						|  | logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p) | 
					
						
						|  |  | 
					
						
						|  | probs = F.softmax(logits, dim=-1) | 
					
						
						|  | if not sample_logits: | 
					
						
						|  | _, x = torch.topk(probs, k=1, dim=-1) | 
					
						
						|  | else: | 
					
						
						|  | x = torch.multinomial(probs, num_samples=1) | 
					
						
						|  |  | 
					
						
						|  | sample = torch.cat((sample, x), dim=1) | 
					
						
						|  | del past | 
					
						
						|  | sample = sample[:, cond_len:] | 
					
						
						|  | return sample | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class KMeans(nn.Module): | 
					
						
						|  | def __init__(self, ncluster=512, nc=3, niter=10): | 
					
						
						|  | super().__init__() | 
					
						
						|  | self.ncluster = ncluster | 
					
						
						|  | self.nc = nc | 
					
						
						|  | self.niter = niter | 
					
						
						|  | self.shape = (3,32,32) | 
					
						
						|  | self.register_buffer("C", torch.zeros(self.ncluster,nc)) | 
					
						
						|  | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) | 
					
						
						|  |  | 
					
						
						|  | def is_initialized(self): | 
					
						
						|  | return self.initialized.item() == 1 | 
					
						
						|  |  | 
					
						
						|  | @torch.no_grad() | 
					
						
						|  | def initialize(self, x): | 
					
						
						|  | N, D = x.shape | 
					
						
						|  | assert D == self.nc, D | 
					
						
						|  | c = x[torch.randperm(N)[:self.ncluster]] | 
					
						
						|  | for i in range(self.niter): | 
					
						
						|  |  | 
					
						
						|  | a = ((x[:, None, :] - c[None, :, :])**2).sum(-1).argmin(1) | 
					
						
						|  |  | 
					
						
						|  | c = torch.stack([x[a==k].mean(0) for k in range(self.ncluster)]) | 
					
						
						|  |  | 
					
						
						|  | nanix = torch.any(torch.isnan(c), dim=1) | 
					
						
						|  | ndead = nanix.sum().item() | 
					
						
						|  | print('done step %d/%d, re-initialized %d dead clusters' % (i+1, self.niter, ndead)) | 
					
						
						|  | c[nanix] = x[torch.randperm(N)[:ndead]] | 
					
						
						|  |  | 
					
						
						|  | self.C.copy_(c) | 
					
						
						|  | self.initialized.fill_(1) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def forward(self, x, reverse=False, shape=None): | 
					
						
						|  | if not reverse: | 
					
						
						|  |  | 
					
						
						|  | bs,c,h,w = x.shape | 
					
						
						|  | assert c == self.nc | 
					
						
						|  | x = x.reshape(bs,c,h*w,1) | 
					
						
						|  | C = self.C.permute(1,0) | 
					
						
						|  | C = C.reshape(1,c,1,self.ncluster) | 
					
						
						|  | a = ((x-C)**2).sum(1).argmin(-1) | 
					
						
						|  | return a | 
					
						
						|  | else: | 
					
						
						|  |  | 
					
						
						|  | bs, HW = x.shape | 
					
						
						|  | """ | 
					
						
						|  | c = self.C.reshape( 1, self.nc,  1, self.ncluster) | 
					
						
						|  | c = c[bs*[0],:,:,:] | 
					
						
						|  | c = c[:,:,HW*[0],:] | 
					
						
						|  | x =      x.reshape(bs,       1, HW,             1) | 
					
						
						|  | x = x[:,3*[0],:,:] | 
					
						
						|  | x = torch.gather(c, dim=3, index=x) | 
					
						
						|  | """ | 
					
						
						|  | x = self.C[x] | 
					
						
						|  | x = x.permute(0,2,1) | 
					
						
						|  | shape = shape if shape is not None else self.shape | 
					
						
						|  | x = x.reshape(bs, *shape) | 
					
						
						|  |  | 
					
						
						|  | return x | 
					
						
						|  |  |