import torch
from torch import nn
from typing import Optional

from .language_config import LanguageModelConfig
from .language_components import DecoderLayer, RMSNorm, KVCache


class LanguageModel(nn.Module):
    def __init__(self, config: LanguageModelConfig):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # Token embedding table; padding_idx keeps the pad token's embedding at zero.
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        # Stack of decoder blocks, each told its index so it can address the KV cache.
        self.layers = nn.ModuleList(
            [DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        # Final RMS normalization applied after the last decoder layer.
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def get_input_embeddings(self):
        return self.embed_tokens

    # Ignore copy
    def forward(
        self,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        kv_cache: Optional[KVCache] = None,
    ) -> torch.FloatTensor:
        # The caller supplies embeddings directly (in the full model these are the
        # merged image and text embeddings), so there is no input_ids path here.
        hidden_states = inputs_embeds

        # Gemma-style scaling: input embeddings are multiplied by sqrt(hidden_size).
        normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
        hidden_states = hidden_states * normalizer

        # Run the decoder stack; each layer reads from and appends to the KV cache
        # when one is provided.
        for decoder_layer in self.layers:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                kv_cache=kv_cache,
            )

        # Final normalization of the hidden states.
        hidden_states = self.norm(hidden_states)
        return hidden_states
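

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal smoke test, assuming LanguageModelConfig accepts the fields read
# above (vocab_size, hidden_size, num_hidden_layers, rms_norm_eps,
# pad_token_id) as keyword arguments, plus whatever DecoderLayer requires.
# The constructor signature and the additive attention-mask shape are
# assumptions, not confirmed by this file.
if __name__ == "__main__":
    config = LanguageModelConfig(
        vocab_size=1000,
        hidden_size=64,
        num_hidden_layers=2,
        rms_norm_eps=1e-6,
        pad_token_id=0,
    )
    model = LanguageModel(config)

    # The forward pass takes pre-computed embeddings; in the full
    # vision-language model these come from merging image and text tokens.
    seq_len = 8
    input_ids = torch.randint(0, config.vocab_size, (1, seq_len))
    inputs_embeds = model.get_input_embeddings()(input_ids)
    position_ids = torch.arange(seq_len).unsqueeze(0)
    # Assumed additive mask of shape (batch, 1, q_len, kv_len): zeros mean
    # "attend everywhere".
    attention_mask = torch.zeros(1, 1, seq_len, seq_len)

    hidden_states = model(
        attention_mask=attention_mask,
        position_ids=position_ids,
        inputs_embeds=inputs_embeds,
        kv_cache=None,
    )
    print(hidden_states.shape)  # expected: torch.Size([1, 8, 64])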