# drama-base / modeling_drama.py
# Last commit (Tom Aarsen): "Integrate Sentence Transformers, prevent manual tokenizer EOS"
# File size at capture: 6.08 kB
from __future__ import annotations
import torch
import torch.nn.functional as F
from transformers import LlamaModel, LlamaConfig, PreTrainedTokenizer
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
class DramaModel(LlamaModel):
    """
    DramaModel is a modified version of the LlamaModel that supports bi-directional attention
    and provides query and document encoding functionalities.

    Causal masking is switched off in every self-attention layer, and the attention mask is
    rebuilt without a causal component, so tokens attend to all non-padding positions.
    Embeddings are the mean of the last hidden state over non-padding tokens, optionally
    truncated to the first ``dim`` features, then L2-normalized.
    """

    def __init__(self, config: LlamaConfig):
        """
        Initializes the DramaModel by disabling causal masking in self-attention layers.

        Args:
            config (LlamaConfig): Model configuration forwarded to ``LlamaModel``.
        """
        # The parent constructor builds ``self.layers``; it must run before the loop below.
        super().__init__(config)
        for layer in self.layers:
            layer.self_attn.is_causal = False
        # Text prepended to every query by ``encode_queries`` (documents get no prefix).
        self.query_prefix = "Query: "
        # Default truncation length used by ``_tokenize`` when the caller passes none.
        self.max_seq_len = 8192
        # Upper bound accepted for the ``dim`` argument of the encode_* methods.
        self.hidden_size = config.hidden_size

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor | None,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        *args,
        **kwargs,
    ):
        """
        Updates the causal mask for attention computations.

        Overrides the ``LlamaModel`` hook so the returned mask only hides padding tokens
        (no causal triangle), which is what makes attention bi-directional.

        Args:
            attention_mask: Padding mask from the caller, an already-expanded 4-D mask,
                or ``None``. Presumably 2-D ``(batch, seq_len)`` otherwise, since that is
                what ``_expand_mask`` expects — confirm for custom callers.
            input_tensor: Input embeddings / hidden states; only its dtype is used.
            cache_position: Accepted for signature compatibility; unused.
            *args, **kwargs: Extra arguments some ``transformers`` versions pass to this
                hook (e.g. past key values, ``output_attentions``); unused.

        Returns:
            With flash_attention_2: the mask itself if it contains any padding, else ``None``.
            Otherwise: ``None`` or a 4-D mask unchanged; any other mask expanded by
            ``AttentionMaskConverter._expand_mask``.
        """
        if self.config._attn_implementation == "flash_attention_2":
            # Flash attention consumes the padding mask directly; drop it if nothing is padded.
            if attention_mask is not None and (attention_mask == 0.0).any():
                return attention_mask
            return None
        if attention_mask is None or attention_mask.dim() == 4:
            return attention_mask
        # NOTE(review): these arguments were lost from the captured source and are
        # reconstructed from ``_expand_mask(mask, dtype, tgt_len=None)`` — confirm.
        return AttentionMaskConverter._expand_mask(
            mask=attention_mask,
            dtype=input_tensor.dtype,
        )

    def _average_pool(
        self, last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Computes the average pooled representation of the last hidden states.

        Positions where ``attention_mask`` is zero are zeroed out and excluded from the
        token count, so padding does not affect the mean.

        Args:
            last_hidden_states: Hidden states; the mask is broadcast over the last axis and
                the sum runs over axis 1, so presumably ``(batch, seq_len, features)``.
            attention_mask: ``(batch, seq_len)`` mask; non-zero marks real tokens.

        Returns:
            torch.Tensor: Mean over the sequence axis, shape ``(batch, features)``.
        """
        last_hidden = last_hidden_states.masked_fill(
            ~attention_mask[..., None].bool(), 0.0
        )
        # Clamp so an all-padding row yields a zero vector instead of 0 / 0 = NaN.
        token_counts = attention_mask.sum(dim=1).clamp(min=1)
        return last_hidden.sum(dim=1) / token_counts[..., None]

    def _tokenize(
        self,
        tokenizer: PreTrainedTokenizer,
        texts: list[str],
        max_seq_len: int | None = None,
    ):
        """
        Tokenizes input text sequences with optional sequence length restriction.

        Args:
            tokenizer (PreTrainedTokenizer): Tokenizer for text processing.
            texts (list[str]): Texts to tokenize.
            max_seq_len (int, optional): Truncation length; defaults to ``self.max_seq_len``.

        Returns:
            The tokenizer's batch encoding on the model's device, later unpacked into
            ``encode`` as ``input_ids`` / ``attention_mask``.
        """
        if max_seq_len is None:
            max_seq_len = self.max_seq_len
        # NOTE(review): the call arguments were lost from the captured source and are
        # reconstructed here — confirm against the published checkpoint. Special tokens
        # (BOS/EOS) are presumably left to the tokenizer rather than added by hand.
        tokenized = tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=max_seq_len,
            return_tensors="pt",
        ).to(self.device)
        return tokenized

    def encode(self, input_ids, attention_mask, dim, *args, **kwargs):
        """
        Pass through the model and compute normalized embeddings.

        No ``torch.no_grad()`` is applied here; wrap the call yourself for inference.

        Args:
            input_ids (torch.Tensor): Input token IDs.
            attention_mask (torch.Tensor): Attention mask tensor.
            dim (int): Dimensionality for output embeddings. Only the first ``dim``
                hidden features are pooled; ``None`` keeps all of them.
            *args, **kwargs: Forwarded to the model's forward pass after ``attention_mask``.

        Returns:
            torch.Tensor: L2-normalized embeddings of shape ``(batch, dim)``.
        """
        # Call the module rather than ``.forward`` so registered forward hooks still run.
        outputs = self(input_ids, attention_mask, *args, **kwargs)
        # Keep only the leading ``dim`` features before pooling (``[:None]`` keeps all).
        embeddings = self._average_pool(
            outputs.last_hidden_state[:, :, :dim], attention_mask
        )
        # normalize embeddings
        embeddings = F.normalize(embeddings, p=2, dim=1)
        return embeddings

    def _validate_encode_inputs(
        self,
        texts: list[str],
        tokenizer: PreTrainedTokenizer,
        dim: int | None,
        name: str,
    ) -> None:
        """
        Shared argument checks for ``encode_queries`` and ``encode_documents``.

        Args:
            texts: The queries or documents to encode.
            tokenizer: Tokenizer that will be used.
            dim: Requested output dimensionality, or ``None``.
            name: Argument name used in error messages (``"queries"`` / ``"documents"``).

        Raises:
            ValueError: If ``texts`` is empty or not a list of strings, ``tokenizer`` is
                ``None``, or ``dim`` lies outside ``[1, self.hidden_size]``.
        """
        if not texts:
            raise ValueError(f"{name} must not be empty.")
        if not isinstance(texts, list) or not all(isinstance(t, str) for t in texts):
            raise ValueError(f"{name} must be a list of strings.")
        if tokenizer is None:
            raise ValueError("tokenizer must not be None.")
        if dim is not None and (dim < 1 or dim > self.hidden_size):
            raise ValueError(f"dim must be in range [1, {self.hidden_size}].")

    def encode_queries(
        self,
        tokenizer: PreTrainedTokenizer,
        queries: list[str],
        max_seq_len: int | None = None,
        dim: int | None = None,
    ):
        """
        Encodes a list of queries into embeddings.

        Each query is prefixed with ``self.query_prefix`` before tokenization.

        Args:
            tokenizer (PreTrainedTokenizer): Tokenizer for text processing.
            queries (list[str]): List of query texts.
            max_seq_len (int, optional): Maximum sequence length.
            dim (int, optional): Dimensionality for output embeddings.

        Returns:
            torch.Tensor: Encoded query embeddings in shape (num_queries, dim), or
            (num_queries, hidden_size) when ``dim`` is ``None``.

        Raises:
            ValueError: On empty/non-string queries, a missing tokenizer, or an
                out-of-range ``dim``.
        """
        self._validate_encode_inputs(queries, tokenizer, dim, "queries")
        queries = [self.query_prefix + query for query in queries]
        tokenized_queries = self._tokenize(tokenizer, queries, max_seq_len)
        return self.encode(**tokenized_queries, dim=dim)

    def encode_documents(
        self,
        tokenizer: PreTrainedTokenizer,
        documents: list[str],
        max_seq_len: int | None = None,
        dim: int | None = None,
    ):
        """
        Encodes a list of documents into embeddings.

        Documents are tokenized as-is (no prefix).

        Args:
            tokenizer (PreTrainedTokenizer): Tokenizer for text processing.
            documents (list[str]): List of document texts.
            max_seq_len (int, optional): Maximum sequence length.
            dim (int, optional): Dimensionality for output embeddings.

        Returns:
            torch.Tensor: Encoded document embeddings in shape (num_documents, dim), or
            (num_documents, hidden_size) when ``dim`` is ``None``.

        Raises:
            ValueError: On empty/non-string documents, a missing tokenizer, or an
                out-of-range ``dim``.
        """
        self._validate_encode_inputs(documents, tokenizer, dim, "documents")
        tokenized_documents = self._tokenize(tokenizer, documents, max_seq_len)
        return self.encode(**tokenized_documents, dim=dim)