# vlm-o/model/language/language_model.py
import torch
from torch import nn
from typing import Optional
from .language_config import LanguageModelConfig
from .language_components import DecoderLayer, RMSNorm, KVCache
class LanguageModel(nn.Module):
    """Decoder-only transformer language model.

    Scales pre-computed input embeddings by ``sqrt(hidden_size)``, runs them
    through a stack of ``DecoderLayer`` blocks, and applies a final RMSNorm.
    The caller supplies ``inputs_embeds`` directly (e.g. after merging image
    and text embeddings upstream); the embedding table itself is exposed via
    ``get_input_embeddings`` for that purpose.
    """

    def __init__(self, config: LanguageModelConfig):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        # padding_idx keeps the pad token's embedding at zero and excludes
        # it from gradient updates.
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def get_input_embeddings(self) -> nn.Embedding:
        """Return the token embedding table (used by callers to embed ids)."""
        return self.embed_tokens

    # Ignore copy
    def forward(
        self,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        kv_cache: Optional[KVCache] = None,
    ) -> torch.FloatTensor:
        """Run the decoder stack over pre-computed input embeddings.

        Args:
            attention_mask: Optional attention mask, forwarded to each layer.
            position_ids: Optional position indices, forwarded to each layer.
            inputs_embeds: Required input embeddings — presumably shaped
                ``(batch, seq, hidden_size)``; the layers define the exact
                contract.
            kv_cache: Optional KV cache, forwarded to each layer for
                incremental decoding.

        Returns:
            Hidden states after all decoder layers and the final RMSNorm.

        Raises:
            ValueError: If ``inputs_embeds`` is not provided.
        """
        if inputs_embeds is None:
            # Without this guard the scaling below raises an opaque
            # TypeError (None * tensor); fail fast with a clear message.
            raise ValueError("`inputs_embeds` must be provided.")
        hidden_states = inputs_embeds
        # Scale embeddings by sqrt(hidden_size). Build the scalar on the
        # activations' device and dtype so the multiply never depends on
        # implicit CPU->GPU scalar promotion (original omitted `device`).
        normalizer = torch.tensor(
            self.config.hidden_size**0.5,
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        hidden_states = hidden_states * normalizer
        for decoder_layer in self.layers:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                kv_cache=kv_cache,
            )
        hidden_states = self.norm(hidden_states)
        return hidden_states