# modeling_llm2vec4cxr.py
"""
Custom model class for LLM2Vec4CXR that properly handles latent attention pooling.
"""

from llm2vec.models.bidirectional_llama import LlamaBiModel
from llm2vec.pooling import LatentAttentionPooling

import torch


class LLM2Vec4CXRModel(LlamaBiModel):
    """
    Custom LlamaBiModel that includes latent attention pooling by default,
    so the checkpoint's latent attention weights are loaded into the model
    instead of triggering a warning about unused weights.
    """

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)

        # Initialize latent attention pooling
        self.latent_attn = LatentAttentionPooling(
            d_model=config.hidden_size,
            num_heads=8,      # standard for this model size
            num_latents=512,  # standard for LLM2Vec
        )

        # Align the pooling module with the embedding weights' device/dtype.
        # LlamaBiModel exposes embed_tokens directly, so check both layouts.
        embed_tokens = getattr(getattr(self, 'model', self), 'embed_tokens', None)
        if embed_tokens is not None:
            self.latent_attn = self.latent_attn.to(
                device=embed_tokens.weight.device, dtype=embed_tokens.weight.dtype
            )

    def forward(self, input_ids, attention_mask=None, embed_mask=None, **kwargs):
        """
        Forward pass that pools the token-level hidden states with latent
        attention; falls back to the raw last hidden state if the pooling
        module is missing.
        """
        # Get the token-level hidden states from the base model
        outputs = super().forward(input_ids, attention_mask=attention_mask, **kwargs)

        # Apply latent attention pooling if available
        if hasattr(self, 'latent_attn') and self.latent_attn is not None:
            if embed_mask is not None:
                # Use embed_mask for instruction-following tasks
                pooled_output = self.latent_attn(outputs.last_hidden_state, embed_mask)
            else:
                # Use attention_mask for simple encoding
                pooled_output = self.latent_attn(outputs.last_hidden_state, attention_mask)
            return pooled_output

        return outputs.last_hidden_state


# Register the model for auto loading. AutoModel.register expects a config
# class (not a name string); exist_ok=True reuses a config already mapped to LlamaModel.
from transformers import AutoModel

AutoModel.register(LLM2Vec4CXRModel.config_class, LLM2Vec4CXRModel, exist_ok=True)
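

if __name__ == "__main__":
    # Minimal usage sketch, not a definitive recipe: the Hub id below is an
    # assumption based on this repository's name, and the tokenizer settings
    # are left at their defaults. It exercises the simple encoding path, where
    # forward() pools over attention_mask because no embed_mask is passed; for
    # instruction-following tasks you would instead pass an embed_mask of the
    # same shape that is 1 only on the tokens to embed.
    from transformers import AutoTokenizer

    repo_id = "lukeingawesome/llm2vec4cxr"  # assumed Hub id
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = LLM2Vec4CXRModel.from_pretrained(repo_id)
    model.eval()

    report = "No focal consolidation, pleural effusion, or pneumothorax."
    batch = tokenizer(report, return_tensors="pt")

    with torch.no_grad():
        embedding = model(batch["input_ids"], attention_mask=batch["attention_mask"])
    print(embedding.shape)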