import multiprocessing
import re
import time
from functools import partial
from pathlib import Path

import openai
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoModel, AutoTokenizer

from Reranking.rerankers.node import Contriever

# Process-level cache so repeated embedding calls reuse an already-loaded
# encoder instead of reloading model weights from disk on every call.
MODEL_CACHE = {}
|
|
def get_openai_embedding(text: str,
                         model: str = "text-embedding-ada-002",
                         max_retry: int = 10,
                         sleep_time: int = 0) -> torch.FloatTensor:
    """
    Get the OpenAI embedding for a given text.

    Args:
        text (str): The input text to be embedded.
        model (str): The model to use for embedding. Default is "text-embedding-ada-002".
        max_retry (int): Maximum number of retries in case of an error. Default is 10.
        sleep_time (int): Sleep time between retries in seconds. Default is 0.

    Returns:
        torch.FloatTensor: The embedding of the input text, with shape (1, emb_dim).
    """
    assert isinstance(text, str), f'text must be str, but got {type(text)}'
    assert len(text) > 0, 'text to be embedded should be non-empty'

    client = openai.OpenAI()

    for _ in range(max_retry):
        try:
            emb = client.embeddings.create(input=[text], model=model)
            return torch.FloatTensor(emb.data[0].embedding).view(1, -1)
        except openai.BadRequestError as e:
            print(f'{e}')
            # The request exceeded the model's context window. Parse the token
            # limits out of the error message and retry with progressively more
            # aggressive word-level truncation (90%, 80%, ..., 10% of the
            # estimated maximum length).
            msg = str(e)
            ori_length = len(text.split(' '))
            match = re.search(r'maximum context length is (\d+) tokens, however you requested (\d+) tokens', msg)
            if match is not None:
                max_length = int(match.group(1))
                cur_length = int(match.group(2))
                ratio = float(max_length) / cur_length
                for reduce_rate in range(9, 0, -1):
                    length = int(ratio * ori_length * (reduce_rate * 0.1))
                    shorten_text = ' '.join(text.split(' ')[:length])
                    try:
                        emb = client.embeddings.create(input=[shorten_text], model=model)
                        print(f'length={length} works! reduce_rate={0.1 * reduce_rate}.')
                        return torch.FloatTensor(emb.data[0].embedding).view(1, -1)
                    except openai.BadRequestError:
                        continue
        except (openai.RateLimitError, openai.APITimeoutError) as e:
            print(f'{e}, sleep for {sleep_time} seconds')
            time.sleep(sleep_time)

    raise RuntimeError("Failed to get embedding after maximum retries")
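
# Minimal usage sketch, assuming OPENAI_API_KEY is set in the environment
# ("text-embedding-ada-002" returns 1536-dimensional vectors):
#
#   emb = get_openai_embedding("GraphRAG retrieves subgraphs, not documents.")
#   emb.shape  # torch.Size([1, 1536])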
|
|
def get_openai_embeddings(texts: list,
                          model: str = "text-embedding-ada-002",
                          n_max_nodes: int = 5) -> torch.FloatTensor:
    """
    Get embeddings for a list of texts using OpenAI's embedding model.

    Args:
        texts (list): List of input texts to be embedded.
        model (str): The model to use for embedding. Default is "text-embedding-ada-002".
        n_max_nodes (int): Maximum number of parallel processes. Default is 5.

    Returns:
        torch.FloatTensor: A tensor of shape (len(texts), emb_dim) containing
            the embeddings of all input texts, in input order.
    """
    assert isinstance(texts, list), f'texts must be list, but got {type(texts)}'
    assert all([len(s) > 0 for s in texts]), 'every string in the `texts` list to be embedded should be non-empty'

    # One worker per text, capped at n_max_nodes parallel API calls.
    processes = min(len(texts), n_max_nodes)
    ada_encoder = partial(get_openai_embedding, model=model)

    with multiprocessing.Pool(processes=processes) as pool:
        results = pool.map(ada_encoder, texts)

    # Each worker returns a (1, emb_dim) tensor; stack them row-wise.
    results = torch.cat(results, dim=0)
    return results
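
# Usage sketch: `pool.map` preserves input order, so row i of the returned
# tensor is the embedding of texts[i].
#
#   embs = get_openai_embeddings(["first passage", "second passage"])
#   embs.shape  # torch.Size([2, 1536]) with text-embedding-ada-002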
|
|
def process(text, encoder):
    """Encode `text` (a string or list of strings) with a SentenceTransformer."""
    embeddings = encoder.encode(text, convert_to_tensor=True)
    return embeddings
|
|
def get_sentence_transformer_embeddings(texts: list,
                                        model_name: str = "all-mpnet-base-v2") -> torch.FloatTensor:
    """
    Get embeddings for a list of texts using a SentenceTransformer model.

    Args:
        texts (list): List of input texts to be embedded.
        model_name (str): The SentenceTransformer model to use. Default is "all-mpnet-base-v2".

    Returns:
        torch.FloatTensor: A tensor of shape (len(texts), emb_dim) on the CPU.
    """
    assert isinstance(texts, list), f'texts must be list, but got {type(texts)}'
    assert all([len(s) > 0 for s in texts]), 'every string in the `texts` list to be embedded should be non-empty'

    # Load the encoder once per process and reuse it on subsequent calls.
    if model_name not in MODEL_CACHE:
        encoder = SentenceTransformer(model_name)
        MODEL_CACHE[model_name] = encoder
    else:
        encoder = MODEL_CACHE[model_name]

    results = process(texts, encoder=encoder)

    # Move the embeddings off the GPU and release cached CUDA memory.
    results = results.cpu()
    torch.cuda.empty_cache()

    return results
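
# Usage sketch ("all-mpnet-base-v2" produces 768-dimensional embeddings):
#
#   embs = get_sentence_transformer_embeddings(["a passage", "another passage"])
#   embs.shape  # torch.Size([2, 768])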
|
|
def get_contriever(dataset_name) -> tuple:
    """
    Load the fine-tuned Contriever encoder and its tokenizer for a dataset.

    Returns:
        tuple: (encoder, tokenizer), with the encoder on CUDA in eval mode.
    """
    contriever_name = 'facebook/contriever'
    encoder = AutoModel.from_pretrained(contriever_name)
    tokenizer = AutoTokenizer.from_pretrained(contriever_name)
    emb_dim = 768

    # Wrap the base model in the project's Contriever head and load the
    # dataset-specific fine-tuned checkpoint.
    encoder = Contriever(encoder, emb_dim=emb_dim)
    current_file = Path(__file__).resolve()
    project_root = current_file.parents[3]
    checkpoint_path = f"{project_root}/Reasoning/text_retrievers/model_checkpoint/contriever/{dataset_name}/best_20250204.pth"
    checkpoint = torch.load(checkpoint_path)
    encoder.load_state_dict(checkpoint)
    encoder.to('cuda')
    encoder.eval()

    return encoder, tokenizer
|
|
@torch.no_grad()
def get_contriever_embeddings(texts, encoder, tokenizer, device) -> torch.FloatTensor:
    """Embed `texts` (a string or list of strings) with a Contriever encoder."""
    if isinstance(texts, str):
        texts = [texts]

    # Tokenize to fixed 512-token inputs and move every tensor to `device`.
    text_enc = tokenizer(texts, padding='max_length', truncation=True, return_tensors='pt', max_length=512)
    input_ids = text_enc['input_ids'].to(device)
    attention_mask = text_enc['attention_mask'].to(device)
    token_type_ids = text_enc.get('token_type_ids', None)
    if token_type_ids is not None:
        token_type_ids = token_type_ids.to(device)

    embeddings = encoder.get_text_emb(input_ids, attention_mask, token_type_ids)

    # Move the embeddings off the GPU and release cached CUDA memory.
    embeddings = embeddings.cpu()
    torch.cuda.empty_cache()

    return embeddings
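
# Usage sketch tying the two Contriever helpers together ('amazon' is a
# hypothetical dataset name; the corresponding checkpoint must exist under
# Reasoning/text_retrievers/model_checkpoint/contriever/):
#
#   encoder, tokenizer = get_contriever('amazon')
#   embs = get_contriever_embeddings(['a passage', 'another passage'],
#                                    encoder, tokenizer, device='cuda')
#   embs.shape  # torch.Size([2, 768])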