from io import BytesIO
from typing import Any, Dict, List, Optional

import torch
from PIL import Image
from transformers import AutoProcessor, MllamaForConditionalGeneration

from sentence_transformers.models import Transformer as BaseTransformer


class MultiModalTransformer(BaseTransformer):
    def __init__(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        tokenizer_args: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        super().__init__(model_name_or_path, **kwargs)
        if tokenizer_args is None:
            tokenizer_args = {}
        tokenizer_args.pop("trust_remote_code", None)

        # Initialize the multimodal processor (tokenizer + image processor)
        self.processor = AutoProcessor.from_pretrained(
            model_name_or_path, cache_dir=cache_dir, **tokenizer_args
        )

    def _load_model(
        self,
        model_name_or_path: str,
        config,
        cache_dir: str,
        backend: str,
        is_peft_model: bool,
        **model_args,
    ) -> None:
        model_args.pop("trust_remote_code", None)
        self.auto_model = MllamaForConditionalGeneration.from_pretrained(
            model_name_or_path, torch_dtype=torch.bfloat16, cache_dir=cache_dir, **model_args
        )

    def forward(
        self, features: Dict[str, torch.Tensor], **kwargs
    ) -> Dict[str, torch.Tensor]:
        # Run the model and expose the last hidden state as token embeddings
        outputs = self.auto_model(
            **features,
            return_dict=True,
            output_hidden_states=True,
            **kwargs,
        )
        features.update({"token_embeddings": outputs.hidden_states[-1]})
        return features

    def tokenize(self, texts: List[Dict] | List[str]) -> Dict[str, torch.Tensor]:
        def process_text_item(item):
            # Plain strings are passed through unchanged, with no image attached
            if isinstance(item, str):
                return item, None

            text, img = "", None
            if "image" in item:
                text += "<|image|>"
                img = item["image"]
                if isinstance(img, bytes):
                    img = Image.open(BytesIO(img)).convert("RGB")
                elif isinstance(img, str):
                    img = Image.open(img).convert("RGB")
                elif not isinstance(img, Image.Image):
                    raise ValueError(f"Unknown image type {type(img)}")
            if "text" in item:
                if text:
                    text += "<|begin_of_text|> "
                text += item["text"].lstrip()
            return text, img

        all_texts, all_images = [], []
        for item in texts:
            text, images = process_text_item(item)
            all_texts.append(text)
            all_images.append(images)

        # Only hand images to the processor if at least one item contains one
        if all_images != [None] * len(all_images):
            inputs = self.processor(
                text=all_texts,
                images=all_images,
                padding="longest",
                truncation=True,
                max_length=self.max_seq_length,
                return_tensors="pt",
            )
        else:
            inputs = self.processor(
                text=all_texts,
                padding="longest",
                truncation=True,
                max_length=self.max_seq_length,
                return_tensors="pt",
            )
        return inputs
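

# Usage sketch (illustrative only): embedding a single text+image item by
# calling tokenize() and forward() on the module directly. The checkpoint name
# and the image path below are placeholders, not part of this module.
if __name__ == "__main__":
    module = MultiModalTransformer("meta-llama/Llama-3.2-11B-Vision-Instruct")

    # One item carrying both a caption and an image file path
    features = module.tokenize(
        [{"text": "a caption paired with an image", "image": "photo.jpg"}]
    )
    # Move the processor outputs to the same device as the loaded model
    features = {k: v.to(module.auto_model.device) for k, v in features.items()}

    with torch.no_grad():
        out = module.forward(features)
    # Token-level embeddings; a pooling module would normally reduce these
    print(out["token_embeddings"].shape)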