"""codegen_torch.py: standalone PyTorch definition of the Salesforce CodeGen-350M-multi architecture, plus helpers to download its checkpoint files."""
import json
import math
import os
import re

import sentencepiece as spm
import torch
import torch.nn as nn
import torch.nn.functional as F
import wget
CODEGEN_FOLDER = "./CodeGenModel"
CODEGEN_MODEL_NAME = "codegen-350M-multi"
CODEGEN_MODEL_WEIGHTS = "pytorch_model.bin"
CODEGEN_CONFIG = "config.json"
CODEGEN_VOCAB = "vocab.json"
CODEGEN_MERGES = "merges.txt"
CODEGEN_MODEL_WEIGHTS_URL = "https://huggingface.co/Salesforce/codegen-350M-multi/resolve/main/pytorch_model.bin"
CODEGEN_CONFIG_URL = "https://huggingface.co/Salesforce/codegen-350M-multi/resolve/main/config.json"
CODEGEN_VOCAB_URL = "https://huggingface.co/Salesforce/codegen-350M-multi/resolve/main/vocab.json"
CODEGEN_MERGES_URL = "https://huggingface.co/Salesforce/codegen-350M-multi/resolve/main/merges.txt"
CODEGEN_FILES_URLS = [
(CODEGEN_MODEL_WEIGHTS_URL, CODEGEN_MODEL_WEIGHTS),
(CODEGEN_CONFIG_URL, CODEGEN_CONFIG),
(CODEGEN_VOCAB_URL, CODEGEN_VOCAB),
(CODEGEN_MERGES_URL, CODEGEN_MERGES),
]
CODEGEN_SPM_URL = "https://huggingface.co/Salesforce/codegen-350M-multi/resolve/main/spm.model"
CODEGEN_SPM = "spm.model"
def ensure_codegen_files_exist():
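    """Download the CodeGen-350M-multi weights, config, tokenizer files and SentencePiece model into CODEGEN_FOLDER if they are missing."""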
os.makedirs(CODEGEN_FOLDER, exist_ok=True)
for url, filename in CODEGEN_FILES_URLS:
filepath = os.path.join(CODEGEN_FOLDER, filename)
if not os.path.exists(filepath):
wget.download(url, out=filepath)
filepath_spm = os.path.join(CODEGEN_FOLDER, CODEGEN_SPM)
if not os.path.exists(filepath_spm):
wget.download(CODEGEN_SPM_URL, out=filepath_spm)
class CodeGenConfig:
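    """Lightweight stand-in for the Hugging Face CodeGen configuration: holds the hyperparameters read from config.json."""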
    def __init__(self, vocab_size, n_positions=2048, n_ctx=2048, n_embd=1024, n_layer=24,
                 n_head=16, n_inner=None, activation_function="gelu_new", resid_pdrop=0.1,
                 embd_pdrop=0.1, attn_pdrop=0.1, layer_norm_epsilon=1e-05,
                 initializer_range=0.02, scale_attn_weights=True, use_cache=True,
                 bos_token_id=50256, eos_token_id=50256, **kwargs):
self.vocab_size = vocab_size
self.n_positions = n_positions
self.n_ctx = n_ctx
self.n_embd = n_embd
self.n_layer = n_layer
self.n_head = n_head
        # Fall back to the conventional 4x expansion when n_inner is unset (the checkpoint's config.json may leave it null).
        self.n_inner = n_inner if n_inner is not None else 4 * n_embd
self.activation_function = activation_function
self.resid_pdrop = resid_pdrop
self.embd_pdrop = embd_pdrop
self.attn_pdrop = attn_pdrop
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range
self.scale_attn_weights = scale_attn_weights
self.use_cache = use_cache
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
for key, value in kwargs.items():
setattr(self, key, value)
@classmethod
def from_dict(cls, config_dict):
return cls(**config_dict)
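# Illustrative helper, not part of the original upload: build a CodeGenConfig from the
# downloaded config.json. Assumes ensure_codegen_files_exist() has already run and that
# config.json provides a "vocab_size" entry; extra keys are absorbed by **kwargs.
def load_codegen_config(folder=CODEGEN_FOLDER):
    with open(os.path.join(folder, CODEGEN_CONFIG), "r", encoding="utf-8") as f:
        return CodeGenConfig.from_dict(json.load(f))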
class CodeGenForCausalLM(nn.Module):
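    """CodeGen decoder wrapped with a linear head that maps hidden states to vocabulary logits."""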
def __init__(self, config):
super().__init__()
self.transformer = CodeGenModel(config)
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
def forward(self, input_ids, attention_mask=None):
transformer_outputs = self.transformer(input_ids, attention_mask=attention_mask)
logits = self.lm_head(transformer_outputs)
return logits
class CodeGenModel(nn.Module):
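    """Decoder backbone: token and learned position embeddings, a stack of CodeGenBlock layers, and a final LayerNorm."""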
def __init__(self, config):
super().__init__()
self.wte = nn.Embedding(config.vocab_size, config.n_embd)
self.wpe = nn.Embedding(config.n_positions, config.n_embd)
self.drop = nn.Dropout(config.embd_pdrop)
self.h = nn.ModuleList([CodeGenBlock(config) for _ in range(config.n_layer)])
self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
    def forward(self, input_ids, attention_mask=None):
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_ids.size(-1))
        if attention_mask is not None:
            # Convert a (batch, seq_len) padding mask of 1s/0s into an additive bias: 0 where attending, -1e4 where masked.
            attention_mask = attention_mask[:, None, None, :].to(dtype=self.wte.weight.dtype)
            attention_mask = (1.0 - attention_mask) * -1e4
        position_ids = torch.arange(0, input_shape[-1], dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
hidden_states = self.drop(hidden_states)
output_shape = input_shape + (hidden_states.size(-1),)
for block in self.h:
hidden_states = block(hidden_states, attention_mask=attention_mask)
hidden_states = self.ln_f(hidden_states)
return hidden_states.view(*output_shape)
class CodeGenBlock(nn.Module):
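    """Pre-LayerNorm transformer block: self-attention and a feed-forward MLP, each wrapped in a residual connection."""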
def __init__(self, config):
super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = CodeGenAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.mlp = CodeGenMLP(config)
def forward(self, hidden_states, attention_mask=None):
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_outputs = self.attn(hidden_states, attention_mask=attention_mask)
hidden_states = residual + attn_outputs
residual = hidden_states
hidden_states = self.ln_2(hidden_states)
feedforward_hidden_states = self.mlp(hidden_states)
hidden_states = residual + feedforward_hidden_states
return hidden_states
class CodeGenMLP(nn.Module):
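    """Two-layer feed-forward network (n_embd -> n_inner -> n_embd) with GELU activation and dropout."""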
def __init__(self, config):
super().__init__()
self.c_fc = nn.Linear(config.n_embd, config.n_inner)
self.c_proj = nn.Linear(config.n_inner, config.n_embd)
self.dropout = nn.Dropout(config.resid_pdrop)
def forward(self, hidden_states):
hidden_states = self.c_fc(hidden_states)
        # The config's "gelu_new" activation corresponds to the tanh-approximated GELU.
        hidden_states = F.gelu(hidden_states, approximate="tanh")
hidden_states = self.c_proj(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
class CodeGenAttention(nn.Module):
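    """Multi-head self-attention with a fixed lower-triangular (causal) mask over up to n_ctx positions."""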
def __init__(self, config):
super().__init__()
self.attn_dropout = nn.Dropout(config.attn_pdrop)
self.resid_dropout = nn.Dropout(config.resid_pdrop)
self.n_head = config.n_head
self.embed_dim = config.n_embd
self.split_size = self.embed_dim
self.c_attn = nn.Linear(self.embed_dim, 3 * self.embed_dim)
self.c_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.scale_attn_weights = config.scale_attn_weights
self.use_cache = config.use_cache
self.register_buffer("bias", torch.tril(torch.ones((config.n_ctx, config.n_ctx), dtype=torch.uint8)).view((1, 1, config.n_ctx, config.n_ctx)))
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
attn_weights = torch.matmul(query, key.transpose(-1, -2))
if self.scale_attn_weights:
attn_weights = attn_weights / math.sqrt(value.size(-1))
mask = self.bias[:, :, :attn_weights.size(-2), :attn_weights.size(-1)]
        mask_value = torch.tensor(-1e4, dtype=attn_weights.dtype, device=attn_weights.device)
        attn_weights = torch.where(mask.bool(), attn_weights, mask_value)
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask
        attn_weights = F.softmax(attn_weights, dim=-1)
attn_weights = self.attn_dropout(attn_weights)
attn_output = torch.matmul(attn_weights, value)
return attn_output
def _split_heads(self, tensor, num_heads, attn_head_size):
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
tensor = tensor.view(*new_shape)
return tensor.permute(0, 2, 1, 3)
    def _merge_heads(self, tensor, num_heads, attn_head_size):
        # Inverse of _split_heads: (batch, heads, seq, head_size) -> (batch, seq, embed_dim).
        tensor = tensor.permute(0, 2, 1, 3).contiguous()
        new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
        return tensor.view(*new_shape)
def forward(self, hidden_states, attention_mask=None, head_mask=None, past_key_value=None, use_cache=False):
query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
query = self._split_heads(query, self.n_head, self.embed_dim // self.n_head)
key = self._split_heads(key, self.n_head, self.embed_dim // self.n_head)
value = self._split_heads(value, self.n_head, self.embed_dim // self.n_head)
if past_key_value is not None:
past_key, past_value = past_key_value
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)
present_key_value = (key, value) if use_cache else None
attn_output = self._attn(query, key, value, attention_mask, head_mask)
attn_output = self._merge_heads(attn_output, self.n_head, self.embed_dim // self.n_head)
attn_output = self.c_proj(attn_output)
attn_output = self.resid_dropout(attn_output)
        # present_key_value is assembled for a KV cache, but callers in this file only use the attention output.
        return attn_output
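# Minimal smoke test, not part of the original upload: builds a toy-sized model (random
# weights, hypothetical hyperparameters chosen only for speed) and runs a dummy forward
# pass with random token IDs. Loading the real checkpoint would additionally require
# ensure_codegen_files_exist(), a config built from config.json, and a state_dict load.
if __name__ == "__main__":
    config = CodeGenConfig(vocab_size=1000, n_positions=128, n_ctx=128, n_embd=64, n_layer=2, n_head=2)
    model = CodeGenForCausalLM(config)
    model.eval()
    dummy_ids = torch.randint(0, config.vocab_size, (1, 16))
    with torch.no_grad():
        logits = model(dummy_ids)
    print(logits.shape)  # expected: torch.Size([1, 16, 1000])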