from llm2vec import LLM2Vec
from peft import PeftModel
from transformers import (
    AutoConfig,
    PretrainedConfig,
    AutoTokenizer,
)
import torch
import logging
import json
import os

logger = logging.getLogger(__name__)


class LLM2VecWrapper(LLM2Vec):
    def __init__(self, *args, **kwargs):
        super(LLM2VecWrapper, self).__init__(*args, **kwargs)

    def to(self, device_or_dtype):
        """Override to method to ensure all modules are properly moved."""
        result = super().to(device_or_dtype)
        # Ensure latent attention pooling is also moved
        if hasattr(result, 'latent_attn') and result.latent_attn is not None:
            result.latent_attn = result.latent_attn.to(device_or_dtype)
        return result

    def prepare_for_tokenization(self, text):
        text = (
            "<|start_header_id|>user<|end_header_id|>\n\n"
            + text.strip()
            + "<|eot_id|>"
        )
        return text

    def encode_text(self, text, max_length=None):
        """
        Encode text to embeddings with proper embed_mask handling.

        Args:
            text (str or list): Text(s) to encode
            max_length (int, optional): Maximum sequence length

        Returns:
            torch.Tensor: Text embeddings
        """
        if max_length is None:
            max_length = getattr(self, 'max_length', 512)

        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        )

        # Add embed_mask (same as attention_mask for simple text encoding)
        inputs["embed_mask"] = inputs["attention_mask"].clone()

        # Move to same device as model
        model_device = next(self.parameters()).device
        inputs = {k: v.to(model_device) for k, v in inputs.items()}

        with torch.no_grad():
            embeddings = self(inputs)

        return embeddings

    def tokenize_with_separator(self, texts, max_length=None, separator='!@#$%^&*()'):
        """
        Tokenize texts with special handling for separator-based splitting.
        This is useful for instruction-following tasks.

        Args:
            texts (list): List of texts to tokenize
            max_length (int, optional): Maximum sequence length
            separator (str): Separator to split instruction from text

        Returns:
            dict: Tokenized inputs with attention masks and embed masks
        """
        if max_length is None:
            max_length = getattr(self, 'max_length', 512)

        texts_2 = []
        original_texts = []
        for text in texts:
            parts = text.split(separator)
            texts_2.append(parts[1] if len(parts) > 1 else "")
            original_texts.append("".join(parts))

        # Tokenize original texts
        tokenized = self.tokenizer(
            original_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        )

        # Create embedding masks for the separated parts
        embed_mask = None
        for t_i, t in enumerate(texts_2):
            ids = self.tokenizer(
                [t],
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_length,
                add_special_tokens=False,
            )
            e_m = torch.zeros_like(tokenized["attention_mask"][t_i])
            if len(ids["input_ids"][0]) > 0:
                e_m[-len(ids["input_ids"][0]):] = torch.ones(len(ids["input_ids"][0]))
            if embed_mask is None:
                embed_mask = e_m.unsqueeze(0)
            else:
                embed_mask = torch.cat((embed_mask, e_m.unsqueeze(0)), dim=0)

        tokenized["embed_mask"] = embed_mask
        return tokenized

    def encode_with_instruction(self, texts, max_length=None, separator='!@#$%^&*()'):
        """
        Encode texts with instruction-following using separator-based processing.
        Args:
            texts (list): List of texts with instructions separated by separator
            max_length (int, optional): Maximum sequence length
            separator (str): Separator between instruction and text

        Returns:
            torch.Tensor: Text embeddings
        """
        tokenized = self.tokenize_with_separator(texts, max_length, separator)

        # Move to same device as model
        model_device = next(self.parameters()).device
        tokenized = {k: v.to(model_device) for k, v in tokenized.items()}

        with torch.no_grad():
            embeddings = self(tokenized)

        return embeddings

    def encode_with_separator(self, texts, device=None, max_length=None, separator='!@#$%^&*()'):
        """
        Encode texts with special separator-based handling for instruction/text pairs.

        Args:
            texts (list): List of texts to encode (with separator for instruction/text pairs)
            device: Device to run on (auto-detect if None)
            max_length: Maximum sequence length (use model default if None)
            separator: Separator string for instruction/text pairs

        Returns:
            torch.Tensor: Embeddings for the texts
        """
        if device is None:
            device = next(self.parameters()).device
        if max_length is None:
            max_length = 512

        # Ensure model is on the right device
        self = self.to(device)

        # Process texts with separator
        texts_2 = []
        original_texts = []
        for text in texts:
            parts = text.split(separator)
            texts_2.append(parts[1] if len(parts) > 1 else "")
            original_texts.append("".join(parts))

        # Tokenize original texts
        tokenized = self.tokenizer(
            original_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        )

        # Create embedding masks
        embed_mask = None
        for t_i, t in enumerate(texts_2):
            ids = self.tokenizer(
                [t],
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_length,
                add_special_tokens=False,
            )
            e_m = torch.zeros_like(tokenized["attention_mask"][t_i])
            if len(ids["input_ids"][0]) > 0:
                e_m[-len(ids["input_ids"][0]):] = torch.ones(len(ids["input_ids"][0]))
            if embed_mask is None:
                embed_mask = e_m.unsqueeze(0)
            else:
                embed_mask = torch.cat((embed_mask, e_m.unsqueeze(0)), dim=0)

        tokenized["embed_mask"] = embed_mask

        # Move to device and compute embeddings
        tokenized = {k: v.to(device) for k, v in tokenized.items()}
        tokenized = {
            k: v.to(self.model.dtype) if v.dtype.is_floating_point else v
            for k, v in tokenized.items()
        }

        with torch.no_grad():
            embeddings = self(tokenized)

        return embeddings

    def compute_similarities(self, query_text, candidate_texts, device=None, separator='!@#$%^&*()'):
        """
        Compute similarity scores between a query text and candidate texts.

        Args:
            query_text (str): The query text (with separator for instruction/text pairs)
            candidate_texts (list): List of candidate texts to compare against
            device: Device to run on (auto-detect if None)
            separator: Separator string for instruction/text pairs

        Returns:
            torch.Tensor: Similarity scores for each candidate
        """
        import torch.nn.functional as F

        if device is None:
            device = next(self.parameters()).device

        # Combine query and candidates
        all_texts = [query_text] + candidate_texts

        # Get embeddings
        embeddings = self.encode_with_separator(all_texts, device=device, separator=separator)

        # Compute similarities between query (first embedding) and candidates
        similarities = F.cosine_similarity(embeddings[0], embeddings[1:], dim=1)

        return similarities

    def _load_latent_attention_weights(self, model_path, use_safetensors=True):
        """
        Automatically load latent attention weights from model files.
        Args:
            model_path: Path to model (local directory or HuggingFace repo)
            use_safetensors: Whether to use safetensors format
        """
        if os.path.isdir(model_path):
            # Local directory - try pytorch_model.bin first
            pytorch_model_path = os.path.join(model_path, "pytorch_model.bin")
            if os.path.exists(pytorch_model_path):
                print(f"Loading latent attention weights from {pytorch_model_path}")
                try:
                    state_dict = torch.load(pytorch_model_path, weights_only=True)
                    latent_attn_weights = {
                        k: v for k, v in state_dict.items() if k.startswith('latent_attn.')
                    }
                    if latent_attn_weights:
                        missing_keys, unexpected_keys = self.latent_attn.load_state_dict(
                            {k.replace('latent_attn.', ''): v for k, v in latent_attn_weights.items()},
                            strict=False,
                        )
                        if not missing_keys and not unexpected_keys:
                            print(f"✅ Successfully loaded {len(latent_attn_weights)} latent attention weights")
                        else:
                            print(f"⚠️ Partial loading: missing={missing_keys}, unexpected={unexpected_keys}")
                    else:
                        print("⚠️ No latent attention weights found in the model file")
                except Exception as e:
                    print(f"❌ Error loading latent attention weights: {e}")
        else:
            # HuggingFace repository - load from safetensors
            if use_safetensors:
                print("Loading latent attention weights from HuggingFace safetensors...")
                try:
                    from safetensors.torch import load_file
                    from huggingface_hub import hf_hub_download

                    # Download the safetensors file
                    safetensors_path = hf_hub_download(repo_id=model_path, filename="model.safetensors")

                    # Load weights from safetensors
                    safetensors_weights = load_file(safetensors_path)

                    # Extract latent attention weights
                    latent_attn_weights = {
                        k: v for k, v in safetensors_weights.items() if k.startswith('latent_attn.')
                    }

                    if latent_attn_weights:
                        print(f"Found {len(latent_attn_weights)} latent attention weights in safetensors")
                        # Load the weights into the latent attention module
                        missing_keys, unexpected_keys = self.latent_attn.load_state_dict(
                            {k.replace('latent_attn.', ''): v for k, v in latent_attn_weights.items()},
                            strict=False,
                        )
                        if not missing_keys and not unexpected_keys:
                            print(f"✅ Successfully loaded {len(latent_attn_weights)} latent attention weights from safetensors")
                        else:
                            print(f"⚠️ Partial loading: missing={missing_keys}, unexpected={unexpected_keys}")
                    else:
                        print("⚠️ No latent attention weights found in safetensors file")
                except Exception as e:
                    print(f"❌ Error loading latent attention weights from safetensors: {e}")

    @classmethod
    def from_pretrained(
        cls,
        base_model_name_or_path,
        peft_model_name_or_path=None,
        merge_peft=False,
        enable_bidirectional=True,
        extra_model_name_or_path=None,
        **kwargs,
    ):
        # pop out encoder args
        keys = ["pooling_mode", "max_length", "doc_max_length", "skip_instruction"]
        encoder_args = {
            key: kwargs.pop(key, None) for key in keys if kwargs.get(key) is not None
        }

        tokenizer = AutoTokenizer.from_pretrained(base_model_name_or_path)
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "left"

        config = AutoConfig.from_pretrained(base_model_name_or_path)
        config_class_name = config.__class__.__name__

        model_class = cls._get_model_class(
            config_class_name, enable_bidirectional=enable_bidirectional
        )
        model = model_class.from_pretrained(base_model_name_or_path, **kwargs)

        if os.path.isdir(base_model_name_or_path) and os.path.exists(
            f"{base_model_name_or_path}/config.json"
        ):
            with open(f"{base_model_name_or_path}/config.json", "r") as fIn:
                config_dict = json.load(fIn)
            config = PretrainedConfig.from_dict(config_dict)
            model.config._name_or_path = config._name_or_path

        # For special case where config.json and adapter weights are in the same directory
        if hasattr(model, "peft_config"):
            model = PeftModel.from_pretrained(
                model,
                base_model_name_or_path,
            )
            model = model.merge_and_unload()

        if peft_model_name_or_path is not None:
            model = PeftModel.from_pretrained(
                model,
                peft_model_name_or_path,
            )
            if merge_peft:
                model = model.merge_and_unload()

        if extra_model_name_or_path is not None:
            logger.info(f"Loading extra model from {extra_model_name_or_path}")
            if not merge_peft:
                model = model.merge_and_unload()
            if isinstance(extra_model_name_or_path, str):
                model = PeftModel.from_pretrained(
                    model,
                    extra_model_name_or_path,
                )
                model = model.merge_and_unload()
            elif isinstance(extra_model_name_or_path, list):
                for extra_model in extra_model_name_or_path:
                    model = PeftModel.from_pretrained(
                        model,
                        extra_model,
                    )
                    peft_model_name_or_path = extra_model
                    model = model.merge_and_unload()
            else:
                raise ValueError(
                    "extra_model_name_or_path should be a string or a list of strings."
                )

        config = {}
        config_addr = (
            peft_model_name_or_path
            if peft_model_name_or_path is not None
            else base_model_name_or_path
        )
        if os.path.exists(f"{config_addr}/llm2vec_config.json"):
            with open(f"{config_addr}/llm2vec_config.json", "r") as fIn:
                llm2vec_config = json.load(fIn)
            config.update(llm2vec_config)

        for key, value in encoder_args.items():
            config[key] = value

        llm2vec_model = cls(model=model, tokenizer=tokenizer, **config)

        # Auto-load latent attention weights if using latent_attention pooling
        if (
            hasattr(llm2vec_model, 'latent_attn')
            and llm2vec_model.latent_attn is not None
            and llm2vec_model.pooling_mode == "latent_attention"
        ):
            llm2vec_model._load_latent_attention_weights(
                base_model_name_or_path, kwargs.get('use_safetensors', True)
            )

        # Ensure the entire model is converted to the requested dtype
        if 'torch_dtype' in kwargs and kwargs['torch_dtype'] is not None:
            llm2vec_model = llm2vec_model.to(kwargs['torch_dtype'])

        return llm2vec_model
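

# Example usage: a minimal sketch of how this wrapper is intended to be called,
# not part of the library API. The base model id, pooling mode, and dtype below
# are assumptions — substitute the base model and any trained PEFT adapter you
# actually use.
if __name__ == "__main__":
    l2v = LLM2VecWrapper.from_pretrained(
        "meta-llama/Meta-Llama-3-8B-Instruct",  # assumed base model
        peft_model_name_or_path=None,           # optionally a trained LoRA adapter path
        pooling_mode="mean",                    # or "latent_attention" if those weights exist
        max_length=512,
        torch_dtype=torch.bfloat16,
    )

    # Plain text encoding: embed_mask mirrors attention_mask.
    passages = ["A photo of a cat.", "A photo of a dog."]
    passage_embeddings = l2v.encode_text(passages)

    # Instruction-conditioned retrieval: the instruction and the query text are
    # joined with the separator string expected by tokenize_with_separator().
    sep = '!@#$%^&*()'
    query = "Given a web search query, retrieve relevant passages" + sep + "what is a llama?"
    scores = l2v.compute_similarities(query, passages)
    print(scores)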