import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import wget
import json
import os
import sentencepiece as spm
import re
CODEGEN_FOLDER = "./CodeGenModel"
CODEGEN_MODEL_NAME = "codegen-350M-multi"
CODEGEN_MODEL_WEIGHTS = "pytorch_model.bin"
CODEGEN_CONFIG = "config.json"
CODEGEN_VOCAB = "vocab.json"
CODEGEN_MERGES = "merges.txt"
CODEGEN_MODEL_WEIGHTS_URL = "https://huggingface.co/Salesforce/codegen-350M-multi/resolve/main/pytorch_model.bin"
CODEGEN_CONFIG_URL = "https://huggingface.co/Salesforce/codegen-350M-multi/resolve/main/config.json"
CODEGEN_VOCAB_URL = "https://huggingface.co/Salesforce/codegen-350M-multi/resolve/main/vocab.json"
CODEGEN_MERGES_URL = "https://huggingface.co/Salesforce/codegen-350M-multi/resolve/main/merges.txt"
CODEGEN_FILES_URLS = [
(CODEGEN_MODEL_WEIGHTS_URL, CODEGEN_MODEL_WEIGHTS),
(CODEGEN_CONFIG_URL, CODEGEN_CONFIG),
(CODEGEN_VOCAB_URL, CODEGEN_VOCAB),
(CODEGEN_MERGES_URL, CODEGEN_MERGES),
]
CODEGEN_SPM_URL = "https://huggingface.co/Salesforce/codegen-350M-multi/resolve/main/spm.model"
CODEGEN_SPM = "spm.model"
def ensure_codegen_files_exist():
    """Download any CodeGen checkpoint files that are not already cached locally."""
    os.makedirs(CODEGEN_FOLDER, exist_ok=True)
    for url, filename in CODEGEN_FILES_URLS:
        filepath = os.path.join(CODEGEN_FOLDER, filename)
        if not os.path.exists(filepath):
            wget.download(url, out=filepath)
    filepath_spm = os.path.join(CODEGEN_FOLDER, CODEGEN_SPM)
    if not os.path.exists(filepath_spm):
        wget.download(CODEGEN_SPM_URL, out=filepath_spm)
class CodeGenConfig:
    """Minimal configuration container mirroring the fields of the downloaded config.json."""

    def __init__(self, vocab_size, n_positions=2048, n_ctx=2048, n_embd=1024, n_layer=24,
                 n_head=16, n_inner=None, activation_function="gelu_new", resid_pdrop=0.1,
                 embd_pdrop=0.1, attn_pdrop=0.1, layer_norm_epsilon=1e-05,
                 initializer_range=0.02, scale_attn_weights=True, use_cache=True,
                 bos_token_id=50256, eos_token_id=50256, **kwargs):
        self.vocab_size = vocab_size
        self.n_positions = n_positions
        self.n_ctx = n_ctx
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_inner = n_inner
        self.activation_function = activation_function
        self.resid_pdrop = resid_pdrop
        self.embd_pdrop = embd_pdrop
        self.attn_pdrop = attn_pdrop
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.scale_attn_weights = scale_attn_weights
        self.use_cache = use_cache
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        # Keep any extra keys from config.json as plain attributes.
        for key, value in kwargs.items():
            setattr(self, key, value)

    @classmethod
    def from_dict(cls, config_dict):
        return cls(**config_dict)
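# A small helper, not part of the original file, sketching how the downloaded
# config.json is presumably turned into a CodeGenConfig via from_dict. It assumes
# ensure_codegen_files_exist() has already been called so the file is on disk.
def load_codegen_config():
    config_path = os.path.join(CODEGEN_FOLDER, CODEGEN_CONFIG)
    with open(config_path, "r") as f:
        config_dict = json.load(f)
    return CodeGenConfig.from_dict(config_dict)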
class CodeGenForCausalLM(nn.Module):
    """Transformer body plus a language-modeling head that maps hidden states to vocabulary logits."""

    def __init__(self, config):
        super().__init__()
        self.transformer = CodeGenModel(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

    def forward(self, input_ids, attention_mask=None):
        transformer_outputs = self.transformer(input_ids, attention_mask=attention_mask)
        logits = self.lm_head(transformer_outputs)
        return logits
class CodeGenModel(nn.Module):
    """Decoder-only transformer: token and learned position embeddings, a stack of blocks, then a final LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([CodeGenBlock(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

    def forward(self, input_ids, attention_mask=None):
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_ids.size(-1))
        position_ids = torch.arange(0, input_shape[-1], dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        hidden_states = inputs_embeds + position_embeds
        hidden_states = self.drop(hidden_states)
        output_shape = input_shape + (hidden_states.size(-1),)
        for block in self.h:
            hidden_states = block(hidden_states, attention_mask=attention_mask)
        hidden_states = self.ln_f(hidden_states)
        return hidden_states.view(*output_shape)
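# The attention_mask that CodeGenModel forwards to each block is added directly to the
# attention scores in CodeGenAttention._attn, so it must already be in additive form:
# 0.0 for tokens to keep and a large negative value for padding. The helper below is a
# sketch (not part of the original file) of building such a mask from a 0/1 padding mask.
def build_additive_attention_mask(padding_mask):
    # padding_mask: (batch, seq) with 1 for real tokens and 0 for padding.
    mask = padding_mask[:, None, None, :].to(dtype=torch.float32)  # (batch, 1, 1, seq)
    return (1.0 - mask) * -1e4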
class CodeGenBlock(nn.Module):
    """Pre-LayerNorm transformer block: self-attention and MLP, each wrapped in a residual connection."""

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = CodeGenAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.mlp = CodeGenMLP(config)

    def forward(self, hidden_states, attention_mask=None):
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_outputs = self.attn(hidden_states, attention_mask=attention_mask)
        hidden_states = residual + attn_outputs
        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feedforward_hidden_states = self.mlp(hidden_states)
        hidden_states = residual + feedforward_hidden_states
        return hidden_states
class CodeGenMLP(nn.Module):
    """Feed-forward block; falls back to the conventional 4 * n_embd inner width when n_inner is unset."""

    def __init__(self, config):
        super().__init__()
        inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, inner_dim)
        self.c_proj = nn.Linear(inner_dim, config.n_embd)
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, hidden_states):
        hidden_states = self.c_fc(hidden_states)
        hidden_states = F.gelu(hidden_states)
        hidden_states = self.c_proj(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states
class CodeGenAttention(nn.Module):
    """Multi-head causal self-attention with a fixed lower-triangular mask buffer."""

    def __init__(self, config):
        super().__init__()
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.n_head = config.n_head
        self.embed_dim = config.n_embd
        self.split_size = self.embed_dim
        self.c_attn = nn.Linear(self.embed_dim, 3 * self.embed_dim)
        self.c_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.scale_attn_weights = config.scale_attn_weights
        self.use_cache = config.use_cache
        # Causal mask: position i may only attend to positions <= i.
        self.register_buffer(
            "bias",
            torch.tril(torch.ones((config.n_ctx, config.n_ctx), dtype=torch.uint8)).view(
                (1, 1, config.n_ctx, config.n_ctx)
            ),
        )

    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        attn_weights = torch.matmul(query, key.transpose(-1, -2))
        if self.scale_attn_weights:
            attn_weights = attn_weights / math.sqrt(value.size(-1))
        # Replace scores at future positions with a large negative value before the softmax.
        mask = self.bias[:, :, :attn_weights.size(-2), :attn_weights.size(-1)]
        attn_weights = torch.where(mask.bool(), attn_weights, torch.tensor(-1e4, device=attn_weights.device))
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = self.attn_dropout(attn_weights)
        attn_output = torch.matmul(attn_weights, value)
        return attn_output
    def _split_heads(self, tensor, num_heads, attn_head_size):
        # (batch, seq, embed_dim) -> (batch, num_heads, seq, head_size)
        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
        tensor = tensor.view(*new_shape)
        return tensor.permute(0, 2, 1, 3)

    def _merge_heads(self, tensor, num_heads, attn_head_size):
        # (batch, num_heads, seq, head_size) -> (batch, seq, embed_dim)
        tensor = tensor.permute(0, 2, 1, 3).contiguous()
        new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
        return tensor.view(*new_shape)

    def forward(self, hidden_states, attention_mask=None, head_mask=None, past_key_value=None, use_cache=False):
        query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
        query = self._split_heads(query, self.n_head, self.embed_dim // self.n_head)
        key = self._split_heads(key, self.n_head, self.embed_dim // self.n_head)
        value = self._split_heads(value, self.n_head, self.embed_dim // self.n_head)
        if past_key_value is not None:
            past_key, past_value = past_key_value
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)
        present_key_value = (key, value) if use_cache else None
        attn_output = self._attn(query, key, value, attention_mask, head_mask)
        attn_output = self._merge_heads(attn_output, self.n_head, self.embed_dim // self.n_head)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)
        outputs = (attn_output, present_key_value)
        # Only the attention output is consumed by CodeGenBlock; the cached key/value pair is unused here.
        return outputs[0]
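# A minimal end-to-end sketch of how the classes above are intended to fit together.
# The helper names build_codegen_model and greedy_generate are not part of the original
# file, and load_codegen_config is the hypothetical helper sketched earlier. Loading the
# downloaded pytorch_model.bin into this re-implementation assumes the checkpoint's
# parameter names line up with the modules defined here; strict=False is used because
# the Hugging Face checkpoint may name or shape some tensors differently.
def build_codegen_model():
    ensure_codegen_files_exist()
    config = load_codegen_config()
    model = CodeGenForCausalLM(config)
    state_dict = torch.load(os.path.join(CODEGEN_FOLDER, CODEGEN_MODEL_WEIGHTS), map_location="cpu")
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    print(f"missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")
    model.eval()
    return model


# Weights that do not load cleanly will not produce useful code, but a greedy decoding
# loop over raw token ids illustrates how the logits returned by the model are meant to be used.
def greedy_generate(model, input_ids, max_new_tokens=16):
    for _ in range(max_new_tokens):
        with torch.no_grad():
            logits = model(input_ids)  # (batch, seq, vocab_size)
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
    return input_ids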