from io import BytesIO
from typing import Any, Dict, List, Optional

import torch
from PIL import Image
from sentence_transformers.models import Transformer as BaseTransformer
from transformers import AutoProcessor, MllamaForConditionalGeneration


class MultiModalTransformer(BaseTransformer):
    """Sentence Transformers module that wraps a Mllama vision-language model
    and produces L2-normalized embeddings via last-token pooling."""

    def __init__(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        tokenizer_args: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        super().__init__(model_name_or_path, **kwargs)
        if tokenizer_args is None:
            tokenizer_args = {}

        # Initialize the multimodal processor (tokenizer + image processor)
        self.processor = AutoProcessor.from_pretrained(
            model_name_or_path, cache_dir=cache_dir, **tokenizer_args
        )

    def _load_model(
        self,
        model_name_or_path: str,
        config,
        cache_dir: str,
        backend: str,
        is_peft_model: bool,
        **model_args,
    ) -> None:
        # Load the Mllama backbone instead of the default AutoModel
        self.auto_model = MllamaForConditionalGeneration.from_pretrained(
            model_name_or_path,
            torch_dtype=torch.bfloat16,
            cache_dir=cache_dir,
            **model_args,
        )

    def forward(
        self, features: Dict[str, torch.Tensor], **kwargs
    ) -> Dict[str, torch.Tensor]:
        # Process inputs through the model
        outputs = self.auto_model(
            **features,
            return_dict=True,
            output_hidden_states=True,
            **kwargs,
        )

        # Apply last-token pooling and L2 normalization
        last_hidden_state = outputs.hidden_states[-1]
        attention_mask = features["attention_mask"]
        sentence_embedding = self._last_pooling(last_hidden_state, attention_mask)

        features.update({"sentence_embedding": sentence_embedding})
        return features

    def _last_pooling(
        self, last_hidden_state: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Apply last-token pooling and L2 normalization."""
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_state.shape[0]
        reps = last_hidden_state[
            torch.arange(batch_size, device=last_hidden_state.device), sequence_lengths
        ]
        return torch.nn.functional.normalize(reps, p=2, dim=-1)

    def tokenize(self, texts: List[List[Dict]] | List[str]) -> Dict[str, torch.Tensor]:
        def process_text_item(item):
            # Plain strings are text-only inputs with no images
            if isinstance(item, str):
                return item, []

            text, images = "", []
            for sub_item in item:
                if sub_item["type"] == "text":
                    text += sub_item["content"]
                elif sub_item["type"] in ["image_bytes", "image_path"]:
                    # Insert the Mllama image placeholder token for each image
                    text += "<|image|>"
                    if sub_item["type"] == "image_bytes":
                        img = Image.open(BytesIO(sub_item["content"])).convert("RGB")
                    else:
                        img = Image.open(sub_item["content"]).convert("RGB")
                    images.append(img)
                else:
                    raise ValueError(f"Unknown data type {sub_item['type']}")
            return text, images

        all_texts, all_images = [], []
        for item in texts:
            text, images = process_text_item(item)
            all_texts.append(text)
            all_images.extend(images)

        # Process inputs through the processor; images are passed only when present
        if all_images:
            inputs = self.processor(
                text=all_texts,
                images=all_images,
                padding="longest",
                truncation=True,
                max_length=self.max_seq_length,
                return_tensors="pt",
            )
        else:
            inputs = self.processor(
                text=all_texts,
                padding="longest",
                truncation=True,
                max_length=self.max_seq_length,
                return_tensors="pt",
            )
        return inputs
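

# --- Example usage (a minimal sketch, not part of the module above) ---
# Wraps the custom module in a SentenceTransformer and encodes one text-only
# input and one mixed text+image input. The checkpoint name and the local
# image path below are assumptions for illustration; any Mllama checkpoint
# loadable by MllamaForConditionalGeneration should behave the same way.
if __name__ == "__main__":
    from sentence_transformers import SentenceTransformer

    transformer = MultiModalTransformer(
        "meta-llama/Llama-3.2-11B-Vision-Instruct"  # assumed checkpoint
    )
    model = SentenceTransformer(modules=[transformer], device="cuda")

    # Inputs follow the schema expected by tokenize() above: plain strings,
    # or lists of {"type": ..., "content": ...} dicts for multimodal items.
    inputs = [
        "a photo of a cat",
        [
            {"type": "text", "content": "Describe this picture: "},
            {"type": "image_path", "content": "cat.jpg"},  # hypothetical local file
        ],
    ]
    embeddings = model.encode(inputs, convert_to_tensor=True)
    print(embeddings.shape)  # (2, hidden_size); rows are L2-normalized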