import json
import math
import os
import re

import sentencepiece as spm
import torch
import torch.nn as nn
import torch.nn.functional as F
import wget

CODEGEN_FOLDER = "./CodeGenModel"
CODEGEN_MODEL_NAME = "codegen-350M-multi"
CODEGEN_MODEL_WEIGHTS = "pytorch_model.bin"
CODEGEN_CONFIG = "config.json"
CODEGEN_VOCAB = "vocab.json"
CODEGEN_MERGES = "merges.txt"
CODEGEN_MODEL_WEIGHTS_URL = "https://huggingface.co/Salesforce/codegen-350M-multi/resolve/main/pytorch_model.bin"
CODEGEN_CONFIG_URL = "https://huggingface.co/Salesforce/codegen-350M-multi/resolve/main/config.json"
CODEGEN_VOCAB_URL = "https://huggingface.co/Salesforce/codegen-350M-multi/resolve/main/vocab.json"
CODEGEN_MERGES_URL = "https://huggingface.co/Salesforce/codegen-350M-multi/resolve/main/merges.txt"
CODEGEN_FILES_URLS = [
    (CODEGEN_MODEL_WEIGHTS_URL, CODEGEN_MODEL_WEIGHTS),
    (CODEGEN_CONFIG_URL, CODEGEN_CONFIG),
    (CODEGEN_VOCAB_URL, CODEGEN_VOCAB),
    (CODEGEN_MERGES_URL, CODEGEN_MERGES),
]
CODEGEN_SPM_URL = "https://huggingface.co/Salesforce/codegen-350M-multi/resolve/main/spm.model"
CODEGEN_SPM = "spm.model"


def ensure_codegen_files_exist():
    """Download every missing CodeGen file into ``CODEGEN_FOLDER``.

    Files that already exist on disk are skipped, so repeated calls only fetch
    what is missing. The folder is created if needed.
    """
    os.makedirs(CODEGEN_FOLDER, exist_ok=True)
    # NOTE(review): the codegen-350M-multi repo appears to ship a BPE tokenizer
    # (vocab.json / merges.txt) rather than a SentencePiece model; confirm that
    # spm.model exists at CODEGEN_SPM_URL, otherwise this download will raise.
    for url, filename in CODEGEN_FILES_URLS + [(CODEGEN_SPM_URL, CODEGEN_SPM)]:
        filepath = os.path.join(CODEGEN_FOLDER, filename)
        if not os.path.exists(filepath):
            wget.download(url, out=filepath)


class CodeGenConfig:
    """Hyperparameter container for the model classes in this module.

    Any extra keyword arguments are stored verbatim as attributes, so a whole
    ``config.json`` dict (with keys this module does not use) can be passed to
    :meth:`from_dict`.
    """

    def __init__(self, vocab_size, n_positions=2048, n_ctx=2048, n_embd=1024,
                 n_layer=24, n_head=16, n_inner=None,
                 activation_function="gelu_new", resid_pdrop=0.1,
                 embd_pdrop=0.1, attn_pdrop=0.1, layer_norm_epsilon=1e-05,
                 initializer_range=0.02, scale_attn_weights=True,
                 use_cache=True, bos_token_id=50256, eos_token_id=50256,
                 **kwargs):
        self.vocab_size = vocab_size
        self.n_positions = n_positions
        self.n_ctx = n_ctx
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        # None means "use the conventional 4 * n_embd" (resolved in CodeGenMLP).
        self.n_inner = n_inner
        self.activation_function = activation_function
        self.resid_pdrop = resid_pdrop
        self.embd_pdrop = embd_pdrop
        self.attn_pdrop = attn_pdrop
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.scale_attn_weights = scale_attn_weights
        self.use_cache = use_cache
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        for key, value in kwargs.items():
            setattr(self, key, value)

    @classmethod
    def from_dict(cls, config_dict):
        """Build a config from a mapping such as a parsed ``config.json``."""
        return cls(**config_dict)


class CodeGenForCausalLM(nn.Module):
    """Transformer trunk plus a bias-free linear head producing vocabulary logits."""

    def __init__(self, config):
        super().__init__()
        self.transformer = CodeGenModel(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

    def forward(self, input_ids, attention_mask=None):
        """Return logits shaped ``input_ids.shape + (vocab_size,)``."""
        transformer_outputs = self.transformer(input_ids, attention_mask=attention_mask)
        return self.lm_head(transformer_outputs)


class CodeGenModel(nn.Module):
    """Token + learned position embeddings, a stack of blocks, and a final LayerNorm.

    NOTE(review): this layout (absolute ``wpe`` embeddings, fused ``c_attn``,
    sequential attention -> MLP) is GPT-2-like; the published CodeGen
    checkpoints appear to use rotary embeddings and different parameter names,
    so loading ``pytorch_model.bin`` may report key mismatches -- confirm.
    """

    def __init__(self, config):
        super().__init__()
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([CodeGenBlock(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

    @staticmethod
    def _expand_attention_mask(attention_mask, dtype):
        """Convert a ``(batch, seq)`` 1/0 padding mask into an additive bias.

        A 2-D mask is read as 1 = attend, 0 = padding and becomes a
        ``(batch, 1, 1, seq)`` tensor of 0 / -1e4 that broadcasts over heads
        and query positions. Masks of any other rank are assumed to already be
        additive and are returned unchanged.
        """
        if attention_mask is None or attention_mask.dim() != 2:
            return attention_mask
        extended = attention_mask[:, None, None, :].to(dtype=dtype)
        return (1.0 - extended) * -1e4

    def forward(self, input_ids, attention_mask=None):
        """Return final hidden states shaped ``input_ids.shape + (n_embd,)``.

        Leading dimensions of ``input_ids`` are flattened into one batch
        dimension for the computation and restored on output.
        """
        input_shape = input_ids.size()
        seq_len = input_shape[-1]
        input_ids = input_ids.view(-1, seq_len)

        position_ids = torch.arange(0, seq_len, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        hidden_states = self.wte(input_ids) + self.wpe(position_ids)
        hidden_states = self.drop(hidden_states)
        attention_mask = self._expand_attention_mask(attention_mask, hidden_states.dtype)

        output_shape = input_shape + (hidden_states.size(-1),)
        for block in self.h:
            hidden_states = block(hidden_states, attention_mask=attention_mask)
        hidden_states = self.ln_f(hidden_states)
        return hidden_states.view(*output_shape)


class CodeGenBlock(nn.Module):
    """Pre-LayerNorm transformer block: residual attention, then residual MLP."""

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = CodeGenAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.mlp = CodeGenMLP(config)

    def forward(self, hidden_states, attention_mask=None):
        residual = hidden_states
        attn_outputs = self.attn(self.ln_1(hidden_states), attention_mask=attention_mask)
        hidden_states = residual + attn_outputs

        residual = hidden_states
        feedforward_hidden_states = self.mlp(self.ln_2(hidden_states))
        return residual + feedforward_hidden_states


def _gelu_new(x):
    """Tanh approximation of GELU, selected by ``activation_function="gelu_new"``."""
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))


# Names accepted in CodeGenConfig.activation_function.
_ACTIVATIONS = {
    "gelu": F.gelu,
    "gelu_new": _gelu_new,
    "relu": F.relu,
    "tanh": torch.tanh,
}


class CodeGenMLP(nn.Module):
    """Two-layer feed-forward network: expand, activate, project back, dropout."""

    def __init__(self, config):
        super().__init__()
        # n_inner defaults to None in CodeGenConfig; nn.Linear needs an int.
        inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd
        try:
            self.act = _ACTIVATIONS[config.activation_function]
        except KeyError:
            raise ValueError(
                f"Unsupported activation_function {config.activation_function!r}; "
                f"expected one of {sorted(_ACTIVATIONS)}"
            ) from None
        self.c_fc = nn.Linear(config.n_embd, inner_dim)
        self.c_proj = nn.Linear(inner_dim, config.n_embd)
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, hidden_states):
        hidden_states = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.c_proj(hidden_states)
        return self.dropout(hidden_states)


class CodeGenAttention(nn.Module):
    """Causal multi-head self-attention with a fused query/key/value projection."""

    def __init__(self, config):
        super().__init__()
        if config.n_embd % config.n_head != 0:
            raise ValueError(
                f"n_embd ({config.n_embd}) must be divisible by n_head ({config.n_head})"
            )
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.n_head = config.n_head
        self.embed_dim = config.n_embd
        self.head_dim = self.embed_dim // self.n_head
        self.split_size = self.embed_dim
        self.c_attn = nn.Linear(self.embed_dim, 3 * self.embed_dim)
        self.c_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.scale_attn_weights = config.scale_attn_weights
        self.use_cache = config.use_cache
        # Lower-triangular causal mask over n_ctx positions (1 = may attend).
        self.register_buffer(
            "bias",
            torch.tril(torch.ones((config.n_ctx, config.n_ctx), dtype=torch.uint8)).view(
                (1, 1, config.n_ctx, config.n_ctx)
            ),
        )

    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        """Scaled dot-product attention with causal masking.

        ``query``/``key``/``value`` are ``(batch, heads, seq, head_dim)``;
        ``key``/``value`` may be longer than ``query`` when a cache was
        prepended. Returns ``(batch, heads, query_len, head_dim)``.
        """
        attn_weights = torch.matmul(query, key.transpose(-1, -2))
        if self.scale_attn_weights:
            attn_weights = attn_weights / math.sqrt(value.size(-1))

        # The current queries are the LAST query_length positions of the key
        # sequence, so take those rows of the causal mask (matters with a cache).
        query_length, key_length = query.size(-2), key.size(-2)
        causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length].bool()
        mask_value = torch.tensor(-1e4, dtype=attn_weights.dtype, device=attn_weights.device)
        attn_weights = torch.where(causal_mask, attn_weights, mask_value)

        if attention_mask is not None:
            # Additive bias, broadcastable to (batch, heads, query, key).
            attn_weights = attn_weights + attention_mask

        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = self.attn_dropout(attn_weights)
        if head_mask is not None:
            attn_weights = attn_weights * head_mask
        return torch.matmul(attn_weights, value)

    def _split_heads(self, tensor, num_heads, attn_head_size):
        """(batch, seq, embed) -> (batch, heads, seq, head_dim)."""
        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
        tensor = tensor.view(*new_shape)
        return tensor.permute(0, 2, 1, 3)

    def _merge_heads(self, tensor, num_heads, attn_head_size):
        """(batch, heads, seq, head_dim) -> (batch, seq, embed)."""
        # Undo the permute from _split_heads; contiguous() is required before view.
        tensor = tensor.permute(0, 2, 1, 3).contiguous()
        new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
        return tensor.view(*new_shape)

    def forward(self, hidden_states, attention_mask=None, head_mask=None,
                past_key_value=None, use_cache=False):
        """Return the attention output, same shape as ``hidden_states``.

        ``past_key_value`` is an optional ``(key, value)`` pair prepended to
        this step's keys/values. ``use_cache`` is accepted for interface
        compatibility only: the updated cache is not returned.
        """
        query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
        query = self._split_heads(query, self.n_head, self.head_dim)
        key = self._split_heads(key, self.n_head, self.head_dim)
        value = self._split_heads(value, self.n_head, self.head_dim)

        if past_key_value is not None:
            past_key, past_value = past_key_value
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        attn_output = self._attn(query, key, value, attention_mask, head_mask)
        attn_output = self._merge_heads(attn_output, self.n_head, self.head_dim)
        attn_output = self.c_proj(attn_output)
        return self.resid_dropout(attn_output)