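"""
Generate and cache embeddings for STaRK semi-structured knowledge base (SKB)
documents or benchmark queries.

In `doc` mode, every SKB candidate document is embedded; in `query` mode,
every benchmark query is embedded. Results are saved as an {id: tensor}
dictionary under `emb_dir`. Illustrative invocation (the script's file name
is not given in the source; `emb_generate.py` is an assumption):

    python emb_generate.py --dataset prime --mode doc --emb_model contriever
"""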
import argparse
import os
import os.path as osp
import random
import sys

import pandas as pd
import torch
from tqdm import tqdm

sys.path.append('.')
from stark_qa import load_skb, load_qa
from stark_qa.tools.api_lib.openai_emb import get_contriever, get_contriever_embeddings
from models.model import get_embeddings

def parse_args():
    parser = argparse.ArgumentParser()

    # Dataset and embedding model selection
    parser.add_argument('--dataset', default='prime', choices=['amazon', 'prime', 'mag'])
    parser.add_argument('--emb_model', default='contriever',
                        choices=[
                            'contriever',
                            'text-embedding-ada-002',
                            'text-embedding-3-small',
                            'text-embedding-3-large',
                            'voyage-large-2-instruct',
                            'GritLM/GritLM-7B',
                            'McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp',
                            'all-mpnet-base-v2',  # for sentence transformers
                        ])

    # Mode settings
    parser.add_argument('--mode', default='query', choices=['doc', 'query'])

    # Path settings
    parser.add_argument('--data_dir', default='data/', type=str)
    parser.add_argument('--emb_dir', default='emb/', type=str)

    # Text settings
    parser.add_argument('--add_rel', action='store_true', default=False,
                        help='add relations to the document text')
    parser.add_argument('--compact', action='store_true', default=False,
                        help='compact the text before feeding it to the model')

    # Evaluation settings
    parser.add_argument('--human_generated_eval', action='store_true',
                        help='if mode is `query`, generate query embeddings for the '
                             'human-generated evaluation split')

    # Batch settings
    parser.add_argument('--batch_size', default=1024, type=int)

    # Encoder kwargs: arguments tagged with metavar="ENCODE" are forwarded
    # to the embedding function as keyword arguments.
    parser.add_argument('--n_max_nodes', default=None, type=int, metavar='ENCODE')
    parser.add_argument('--device', default=None, type=str, metavar='ENCODE')
    parser.add_argument('--peft_model_name', default=None, type=str,
                        help='LLM2Vec PEFT model', metavar='ENCODE')
    parser.add_argument('--instruction', default=None, type=str,
                        help='GritLM/LLM2Vec instruction', metavar='ENCODE')

    args = parser.parse_args()

    # Collect encode kwargs via the custom metavar "ENCODE". Note that this
    # relies on argparse's private `_option_string_actions` mapping.
    encode_kwargs = {
        k: v for k, v in vars(args).items()
        if v is not None and parser._option_string_actions[f'--{k}'].metavar == 'ENCODE'
    }
    return args, encode_kwargs
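
# Illustrative example (not from the source): running with
#   --emb_model GritLM/GritLM-7B --device cuda --batch_size 256
# yields encode_kwargs == {'device': 'cuda'}, since only arguments declared
# with metavar="ENCODE" (and not left at None) are forwarded to the
# embedding function.
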
if __name__ == '__main__':
    args, encode_kwargs = parse_args()

    mode_suffix = '_human_generated_eval' if args.human_generated_eval and args.mode == 'query' else ''
    mode_suffix += '_no_rel' if not args.add_rel else ''
    mode_suffix += '_no_compact' if not args.compact else ''

    emb_dir = osp.join(args.emb_dir, args.dataset, args.emb_model, f'{args.mode}{mode_suffix}')
    csv_cache = osp.join(args.data_dir, args.dataset, f'{args.mode}{mode_suffix}.csv')

    print(f'Embedding directory: {emb_dir}')
    os.makedirs(emb_dir, exist_ok=True)
    os.makedirs(osp.dirname(csv_cache), exist_ok=True)
    # Collect the ids to embed and the output path for the embedding dict
    if args.mode == 'doc':
        skb = load_skb(args.dataset)
        lst = skb.candidate_ids
        emb_path = osp.join(emb_dir, 'candidate_emb_dict.pt')
    elif args.mode == 'query':
        qa_dataset = load_qa(args.dataset, human_generated_eval=args.human_generated_eval)
        lst = [qa_dataset[i][1] for i in range(len(qa_dataset))]
        emb_path = osp.join(emb_dir, 'query_emb_dict.pt')
    random.shuffle(lst)
    # Load existing embeddings if they exist, so an interrupted run can resume
    if osp.exists(emb_path):
        emb_dict = torch.load(emb_path)
        exist_emb_indices = list(emb_dict.keys())
        print(f'Loaded existing embeddings from {emb_path}. Size: {len(emb_dict)}')
    else:
        emb_dict = {}
        exist_emb_indices = []
    # Build the id -> text mapping. In doc mode, reuse the CSV cache when it
    # exists; otherwise gather the texts and, in doc mode, write the cache.
    if args.mode == 'doc' and osp.exists(csv_cache):
        df = pd.read_csv(csv_cache)
        text_dict = dict(zip(df['index'], df['text']))
        # Ensure that the indices in the cache match the expected indices
        assert set(text_dict.keys()) == set(lst), 'Indices in cache do not match the candidate indices.'
    else:
        text_dict = {idx: qa_dataset.get_query_by_qid(idx) if args.mode == 'query'
                     else skb.get_doc_info(idx, add_rel=args.add_rel, compact=args.compact)
                     for idx in tqdm(lst, desc="Gathering docs")}
        if args.mode == 'doc':
            # Write the cache only when it did not already exist; rewriting it
            # with a partial index set would clobber a complete cache.
            pd.DataFrame({'index': lst, 'text': [text_dict[idx] for idx in lst]}).to_csv(csv_cache, index=False)

    # Only embed ids that do not already have an embedding, so interrupted
    # runs resume where they left off.
    exist_emb_indices = set(exist_emb_indices)
    indices = [idx for idx in lst if idx not in exist_emb_indices]
    texts = [text_dict[idx] for idx in indices]
    print(f'Generating embeddings for {len(texts)} texts...')

    if args.emb_model == 'contriever':
        encoder, tokenizer = get_contriever(dataset_name=args.dataset)

    for i in tqdm(range(0, len(texts), args.batch_size), desc="Generating embeddings"):
        batch_texts = texts[i:i + args.batch_size]
        if args.emb_model == 'contriever':
            batch_embs = get_contriever_embeddings(batch_texts, encoder=encoder, tokenizer=tokenizer,
                                                   device=args.device or 'cuda')
        else:
            batch_embs = get_embeddings(batch_texts, args.emb_model, **encode_kwargs)
        batch_embs = batch_embs.view(len(batch_texts), -1).cpu()

        # Store each embedding under its id with shape (1, d)
        for idx, emb in zip(indices[i:i + args.batch_size], batch_embs):
            emb_dict[idx] = emb.view(1, -1)

    torch.save(emb_dict, emb_path)
    print(f'Saved {len(emb_dict)} embeddings to {emb_path}!')
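
    # Illustrative downstream usage (not from the source): each value in the
    # saved dictionary has shape (1, d), so a candidate matrix can be built as
    #   emb_dict = torch.load(emb_path)
    #   emb_matrix = torch.cat([emb_dict[i] for i in sorted(emb_dict)], dim=0)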