|
import torch |
|
import numpy as np |
|
import sys |
|
import os |
|
|
|
from .utils import freeze |
|
|
|
|
|
class BaseEmbedder:
    """Common plumbing for text embedders: holds the language model and
    exposes fluent device/freeze/compile helpers."""

    def __init__(self, conf):
        """Read checkpoint/tokenizer paths and max token length from *conf*.

        NOTE(review): the config key is spelled ``tokens_lenght`` (sic) —
        must stay in sync with the config files, so it is kept as-is.
        """
        params = conf.text_embedder.params
        self.checkpoint_path = params.checkpoint_path
        self.tokenizer_path = params.tokenizer_path
        self.max_length = conf.text_embedder.tokens_lenght
        # Concrete subclasses are expected to assign the actual model here.
        self.llm = None

    def to(self, device='cpu', dtype=torch.float32):
        """Move the underlying model to *device*/*dtype*; returns self (fluent)."""
        self.llm = self.llm.to(device=device, dtype=dtype)
        return self

    def freeze(self):
        """Disable gradient updates on the model via utils.freeze; returns self."""
        self.llm = freeze(self.llm)
        return self

    def compile(self):
        """Wrap the model with torch.compile; returns self (fluent)."""
        self.llm = torch.compile(self.llm)
        return self
|
|
|
|
|
class EmbedderWithTokenizer(BaseEmbedder):
    """Embedder variant that pairs the language model with a tokenizer."""

    def __init__(self, conf):
        super().__init__(conf)
        # Concrete subclasses are expected to assign the tokenizer here.
        self.tokenizer = None

    def tokenize(self, text):
        """Tokenize *text* to a fixed-length (padded/truncated to
        ``self.max_length``) id tensor placed on the model's device."""
        encoded = self.tokenizer(
            text,
            max_length=self.max_length,
            truncation=True,
            add_special_tokens=True,
            padding='max_length',
            return_tensors='pt',
        )
        input_ids = encoded.input_ids
        return input_ids.to(self.llm.device)

    def __call__(self, text):
        """Embed *text*: returns the model's first output (the last hidden state)."""
        token_ids = self.tokenize(text)
        outputs = self.llm(token_ids, output_hidden_states=True)
        return outputs[0]
|
|
|
|
|
class T5TextEmbedder(EmbedderWithTokenizer):
    """Text embedder backed by a pretrained T5 encoder."""

    def __init__(self, conf):
        # Imported lazily so `transformers` is only required when T5 is used.
        from transformers import T5EncoderModel, T5Tokenizer

        super().__init__(conf)

        self.llm = T5EncoderModel.from_pretrained(self.checkpoint_path)
        self.tokenizer = T5Tokenizer.from_pretrained(
            self.tokenizer_path,
            clean_up_tokenization_spaces=False,
        )
|
|
|
|
|
def get_text_embedder(conf):
    """Factory returning the configured text embedder (T5-based for now)."""
    embedder = T5TextEmbedder(conf)
    return embedder
|
|