# vlm-o/model/multimodal/multimodal_components.py

import torch
from torch import nn
from typing import Optional

from .multimodal_config import MultiModalConfig
from ..utils.kv_cache import KVCache
from ..language.language_model import LanguageModel


class CausalLM(nn.Module):
    """Language model plus an LM head, driven by precomputed input embeddings."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.model = LanguageModel(config)
        self.vocab_size = config.vocab_size
        # Maps hidden states to vocabulary logits; tied to the input
        # embedding matrix via tie_weights().
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def tie_weights(self):
        # Share the input embedding weights with the output projection.
        self.lm_head.weight = self.model.embed_tokens.weight

    def forward(
        self,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        kv_cache: Optional[KVCache] = None,
    ) -> dict:
        # The backbone consumes embeddings directly (no token ids), so image
        # features and text embeddings can be merged upstream of this call.
        hidden_states = self.model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            kv_cache=kv_cache,
        )
        # Cast logits to float32 for numerically stable loss/softmax downstream.
        logits = self.lm_head(hidden_states).float()

        return_data = {"logits": logits}
        if kv_cache is not None:
            return_data["kv_cache"] = kv_cache
        return return_data
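

# Usage sketch for CausalLM (illustrative only; the names and shapes below are
# assumptions, since the mask layout and cache handling live in LanguageModel,
# which is defined elsewhere in the package):
#
#   lm = CausalLM(text_config)
#   lm.tie_weights()                     # share embeddings with the LM head
#   out = lm(
#       attention_mask=mask,             # layout depends on the attention impl
#       position_ids=pos,                # [batch, seq]
#       inputs_embeds=embeds,            # [batch, seq, hidden_size], e.g. text
#                                        # embeddings merged with image features
#       kv_cache=cache,                  # optional KVCache for incremental decoding
#   )
#   logits = out["logits"]               # [batch, seq, vocab_size], float32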


class MultiModalProjector(nn.Module):
    """Projects vision-encoder features into the language model's embedding space."""

    def __init__(self, config: MultiModalConfig):
        super().__init__()
        self.linear = nn.Linear(config.vision_config.hidden_size, config.vision_config.projection_dim, bias=True)

    def forward(self, image_features):
        # [batch, num_patches, vision_hidden] -> [batch, num_patches, projection_dim]
        return self.linear(image_features)
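

if __name__ == "__main__":
    # Minimal smoke test for MultiModalProjector (a sketch, not part of the
    # model). SimpleNamespace stands in for MultiModalConfig, providing only
    # the two fields the projector reads; the sizes are illustrative. Because
    # this module uses relative imports, run it as a module from the repo
    # root (python -m ...), not as a bare script.
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        vision_config=SimpleNamespace(hidden_size=768, projection_dim=2048)
    )
    projector = MultiModalProjector(cfg)
    patches = torch.randn(2, 196, 768)   # [batch, num_patches, vision_hidden]
    projected = projector(patches)
    assert projected.shape == (2, 196, 2048)
    print("MultiModalProjector ok:", tuple(projected.shape))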