lukeingawesome committed
Commit c589499 · verified · Parent: e2bbf2a

Add custom model class with proper latent attention architecture

Files changed (1):
  1. modeling_llm2vec4cxr.py  +55 -0
modeling_llm2vec4cxr.py ADDED
@@ -0,0 +1,55 @@
+ """
+ Custom model class for LLM2Vec4CXR that properly handles latent attention pooling.
+ """
+ 
+ from llm2vec.models.bidirectional_llama import LlamaBiModel
+ from llm2vec.pooling import LatentAttentionPooling
+ import torch
+ import torch.nn as nn
+ 
+ 
+ class LLM2Vec4CXRModel(LlamaBiModel):
+     """
+     Custom LlamaBiModel that includes latent attention pooling by default.
+     This prevents the warning about unused latent attention weights.
+     """
+ 
+     def __init__(self, config, **kwargs):
+         super().__init__(config, **kwargs)
+ 
+         # Initialize latent attention pooling
+         self.latent_attn = LatentAttentionPooling(
+             d_model=config.hidden_size,
+             num_heads=8,     # Standard for this model size
+             num_latents=512  # Standard for LLM2Vec
+         )
+ 
+         # Move the pooling module to the same device/dtype as the base model's embeddings
+         embed_tokens = getattr(getattr(self, "model", self), "embed_tokens", None)
+         if embed_tokens is not None:
+             device, dtype = embed_tokens.weight.device, embed_tokens.weight.dtype
+             self.latent_attn = self.latent_attn.to(device=device, dtype=dtype)
+ 
+     def forward(self, input_ids, attention_mask=None, embed_mask=None, **kwargs):
+         """
+         Forward pass that applies latent attention pooling to the hidden states.
+         """
+         # Get the base model output
+         outputs = super().forward(input_ids, attention_mask=attention_mask, **kwargs)
+ 
+         # If latent attention pooling is available, return the pooled embedding
+         if getattr(self, "latent_attn", None) is not None:
+             if embed_mask is not None:
+                 # Use embed_mask for instruction-following tasks
+                 pooled_output = self.latent_attn(outputs.last_hidden_state, embed_mask)
+             else:
+                 # Use attention_mask for plain encoding
+                 pooled_output = self.latent_attn(outputs.last_hidden_state, attention_mask)
+             return pooled_output
+ 
+         return outputs.last_hidden_state
+ 
+ 
+ # Register under the base config class so AutoModel can resolve this model
+ from transformers import AutoModel
+ AutoModel.register(LLM2Vec4CXRModel.config_class, LLM2Vec4CXRModel, exist_ok=True)
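For context, a minimal usage sketch of the class added in this commit. It is not part of the diff: the repo id, the example report text, and the presence of an auto_map entry in the checkpoint's config.json are assumptions made for illustration.

# Hypothetical usage sketch -- the repo id below is an assumption, and
# trust_remote_code is needed so the Hub copy of modeling_llm2vec4cxr.py is imported.
import torch
from transformers import AutoModel, AutoTokenizer

repo_id = "lukeingawesome/llm2vec4cxr"  # assumed; substitute the actual checkpoint
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True).eval()

batch = tokenizer(
    ["Small right pleural effusion, unchanged from prior."],  # example report text
    padding=True,
    return_tensors="pt",
)

with torch.no_grad():
    # forward() returns the pooled embedding when latent attention pooling is active,
    # otherwise the raw last hidden state.
    emb = model(batch["input_ids"], attention_mask=batch["attention_mask"])

print(emb.shape)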