import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
from sentence_transformers import SentenceTransformer
class MldTextEncoder(nn.Module):
    """Wrap a pretrained text encoder (CLIP, BERT, or a sentence-level T5)
    behind a single interface that maps a list of strings to embeddings."""

    def __init__(self, modelpath: str, last_hidden_state: bool = False) -> None:
        super().__init__()

        if 't5' in modelpath:
            # Sentence-level T5 checkpoints are loaded through
            # sentence-transformers, which bundles its own tokenizer.
            self.text_model = SentenceTransformer(modelpath)
            self.tokenizer = self.text_model.tokenizer
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(modelpath)
            self.text_model = AutoModel.from_pretrained(modelpath)

        self.max_length = self.tokenizer.model_max_length
        if "clip" in modelpath:
            self.text_encoded_dim = self.text_model.config.text_config.hidden_size
            # "clip_hidden" keeps the per-token hidden states;
            # "clip" uses the pooled text features.
            self.name = "clip_hidden" if last_hidden_state else "clip"
        elif "bert" in modelpath:
            self.name = "bert"
            self.text_encoded_dim = self.text_model.config.hidden_size
        elif 't5' in modelpath:
            self.name = 't5'
            # Width of the pooled sentence embeddings.
            self.text_encoded_dim = self.text_model.get_sentence_embedding_dimension()
        else:
            raise ValueError(f"Model {modelpath} not supported")
    def forward(self, texts: list[str]) -> torch.Tensor:
        # Tokenize the prompts for models that take explicit input ids.
        if self.name in ["clip", "clip_hidden"]:
            text_inputs = self.tokenizer(
                texts,
                padding="max_length",
                truncation=True,
                max_length=self.max_length,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            # Truncate to the maximum sequence length CLIP can handle.
            if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
                text_input_ids = text_input_ids[:, :self.tokenizer.model_max_length]
        elif self.name == "bert":
            text_inputs = self.tokenizer(texts, return_tensors="pt", padding=True)

        if self.name == "clip":
            # Pooled text features: (batch_size, text_encoded_dim)
            text_embeddings = self.text_model.get_text_features(
                text_input_ids.to(self.text_model.device))
            # Add a sequence dimension: (batch_size, 1, text_encoded_dim)
            text_embeddings = text_embeddings.unsqueeze(1)
        elif self.name == "clip_hidden":
            # Per-token hidden states: (batch_size, seq_length, text_encoded_dim)
            text_embeddings = self.text_model.text_model(
                text_input_ids.to(self.text_model.device)).last_hidden_state
        elif self.name == "bert":
            # Per-token hidden states: (batch_size, seq_length, text_encoded_dim)
            text_embeddings = self.text_model(
                **text_inputs.to(self.text_model.device)).last_hidden_state
        elif self.name == 't5':
            # sentence-transformers tokenizes internally and returns one
            # pooled vector per sentence.
            text_embeddings = self.text_model.encode(
                texts,
                show_progress_bar=False,
                convert_to_tensor=True,
                batch_size=len(texts),
            )
            # Add a sequence dimension: (batch_size, 1, text_encoded_dim)
            text_embeddings = text_embeddings.unsqueeze(1)
        else:
            raise NotImplementedError(f"Model {self.name} not implemented")

        return text_embeddings
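

# --- Usage sketch ---
# A minimal, illustrative smoke test. The checkpoint name below is an
# assumption chosen for demonstration, not one mandated by this module.
if __name__ == "__main__":
    encoder = MldTextEncoder("openai/clip-vit-base-patch32")
    with torch.no_grad():
        embeddings = encoder(["a person walks forward", "a person jumps in place"])
    # In "clip" mode each sentence is pooled to a single vector:
    # (batch_size, 1, text_encoded_dim)
    print(embeddings.shape)  # expected for this checkpoint: torch.Size([2, 1, 512])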