from typing import Optional

import torch
from torch import nn

from .multimodal_config import MultiModalConfig
from ..utils.kv_cache import KVCache
from ..language.language_model import LanguageModel


class CausalLM(nn.Module):
    """Wraps the language model trunk with an LM head that maps hidden states to vocabulary logits."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.model = LanguageModel(config)
        self.vocab_size = config.vocab_size
        # Output projection to vocabulary logits; its weight is shared with the
        # input embedding table via tie_weights().
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def get_input_embeddings(self):
        # Expose the token embedding table so the caller can build inputs_embeds
        # (e.g. to splice projected image features into the text sequence).
        return self.model.embed_tokens

    def tie_weights(self):
        # Weight tying: the output projection reuses the input embedding matrix,
        # reducing the parameter count.
        self.lm_head.weight = self.model.embed_tokens.weight

    def forward(
        self,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        kv_cache: Optional[KVCache] = None,
    ) -> dict:
        # Run the transformer trunk on pre-computed embeddings (typically text token
        # embeddings with projected image features already merged in by the caller).
        hidden_states = self.model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            kv_cache=kv_cache,
        )
        # Project to vocabulary logits and upcast to float32 for numerical stability.
        logits = self.lm_head(hidden_states).float()
        return_data = {
            "logits": logits,
        }
        if kv_cache is not None:
            # Hand the cache back so the caller can reuse it on the next decoding step.
            return_data["kv_cache"] = kv_cache
        return return_data
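
# --- Hedged usage sketch (not part of the original module) ---------------------
# Shows how CausalLM.forward and the KV cache would typically be consumed in a
# greedy decoding loop. Assumptions: the caller has already built `inputs_embeds`
# (text token embeddings with projected image features spliced in), KVCache() can
# be constructed without arguments, and embed_tokens is an nn.Embedding that maps
# token ids to vectors. These are assumptions about sibling modules, not APIs
# confirmed by this file.
def _greedy_decode_sketch(model: CausalLM, inputs_embeds, attention_mask, position_ids, max_new_tokens: int = 16):
    kv_cache = KVCache()  # assumption: no-argument constructor
    tokens = []
    for _ in range(max_new_tokens):
        out = model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            kv_cache=kv_cache,
        )
        kv_cache = out["kv_cache"]
        # Greedy pick from the last position's logits: shape (batch, 1).
        next_token = out["logits"][:, -1, :].argmax(dim=-1, keepdim=True)
        tokens.append(next_token)
        # The next step feeds only the new token, embedded through the shared table.
        inputs_embeds = model.get_input_embeddings()(next_token)
        attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
        position_ids = position_ids[:, -1:] + 1  # simplified; real code may derive positions from the mask
    return torch.cat(tokens, dim=-1)
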

class MultiModalProjector(nn.Module):
    """Linear projection of vision encoder features to projection_dim so they can be merged with text token embeddings."""

    def __init__(self, config: MultiModalConfig):
        super().__init__()
        self.linear = nn.Linear(
            config.vision_config.hidden_size,
            config.vision_config.projection_dim,
            bias=True,
        )

    def forward(self, image_features):
        # (..., vision_hidden_size) -> (..., projection_dim)
        hidden_states = self.linear(image_features)
        return hidden_states
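
# --- Hedged smoke test (not part of the original module) -----------------------
# Minimal shape check for MultiModalProjector. The SimpleNamespace below stands in
# for MultiModalConfig; hidden_size and projection_dim are the only attributes the
# class reads, and the sizes here are arbitrary examples. Run as a module
# (`python -m <package>.<this_module>`) so the relative imports at the top resolve.
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(vision_config=SimpleNamespace(hidden_size=768, projection_dim=2048))
    projector = MultiModalProjector(cfg)
    image_features = torch.randn(1, 256, 768)  # (batch, num_patches, vision_hidden_size)
    projected = projector(image_features)
    print(projected.shape)  # torch.Size([1, 256, 2048])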