import json
import logging
from dataclasses import dataclass, field
from typing import Optional

import faiss
import numpy as np
import pandas as pd
import torch
from pyvi.ViTokenizer import tokenize
from tqdm import tqdm
from transformers import AutoTokenizer, HfArgumentParser

from bi.model import SharedBiEncoder
from bi.preprocess import preprocess_question
from cross_rerank.model import RerankerForInference

logger = logging.getLogger(__name__)

@dataclass
class Args:
    encoder: str = field(
        default="vinai/phobert-base-v2",
        metadata={'help': 'The bi-encoder name or path.'}
    )
    tokenizer: Optional[str] = field(
        default=None,
        metadata={'help': 'The bi-encoder tokenizer name or path (defaults to `encoder`).'}
    )
    cross_checkpoint: str = field(
        default="vinai/phobert-base-v2",
        metadata={'help': 'The cross-encoder reranker name or path.'}
    )
    cross_tokenizer: Optional[str] = field(
        default=None,
        metadata={'help': 'The reranker tokenizer name or path (defaults to `cross_checkpoint`).'}
    )
    sentence_pooling_method: str = field(
        default="cls",
        metadata={'help': 'Sentence embedding method.'}
    )
    fp16: bool = field(
        default=False,
        metadata={'help': 'Use fp16 in inference?'}
    )
    max_query_length: int = field(
        default=32,
        metadata={'help': 'Max query length.'}
    )
    max_passage_length: int = field(
        default=256,
        metadata={'help': 'Max passage length.'}
    )
    cross_max_length: int = field(
        default=256,
        metadata={'help': 'Max reranker input length.'}
    )
    cross_batch_size: int = field(
        default=32,
        metadata={'help': 'Reranker inference batch size.'}
    )
    batch_size: int = field(
        default=128,
        metadata={'help': 'Bi-encoder inference batch size.'}
    )
    index_factory: str = field(
        default="Flat",
        metadata={'help': 'Faiss index factory string.'}
    )
    k: int = field(
        default=1000,
        metadata={'help': 'How many neighbors to retrieve?'}
    )
    top_k: int = field(
        default=1000,
        metadata={'help': 'How many neighbors to rerank?'}
    )
    data_path: str = field(
        default="/kaggle/input/zalo-data",
        metadata={'help': 'Path to zalo data.'}
    )
    data_type: str = field(
        default="test",
        metadata={'help': 'Which data split to test.'}
    )
    corpus_file: str = field(
        default="/kaggle/input/zalo-data",
        metadata={'help': 'Path to zalo corpus.'}
    )
    data_file: Optional[str] = field(
        default=None,
        metadata={'help': 'Path to evaluated data.'}
    )
    bi_data: bool = field(
        default=False,
        metadata={'help': 'Dump reranked results as bi-encoder training data?'}
    )
    save_embedding: bool = field(
        default=False,
        metadata={'help': 'Save embeddings in memmap at save_path?'}
    )
    load_embedding: str = field(
        default='',
        metadata={'help': 'Path to saved embeddings.'}
    )
    save_path: str = field(
        default="embeddings.memmap",
        metadata={'help': 'Path to save embeddings.'}
    )

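# Example invocation (illustrative only -- the script filename and the data
# paths are placeholders, not fixed anywhere in this file):
#
#   python run_eval.py \
#       --encoder vinai/phobert-base-v2 \
#       --cross_checkpoint path/to/reranker \
#       --data_path /kaggle/input/zalo-data \
#       --corpus_file /kaggle/input/zalo-data/corpus.csv \
#       --data_type test --k 1000 --top_k 30
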
def index(model: SharedBiEncoder, tokenizer: AutoTokenizer, corpus, batch_size: int = 16,
          max_length: int = 512, index_factory: str = "Flat", save_path: str = None,
          save_embedding: bool = False, load_embedding: str = ''):
    """
    1. Encode the entire corpus into dense embeddings (or load saved ones);
    2. Create a faiss index;
    3. Optionally save the embeddings as a memmap.
    """
    if load_embedding != '':
        # Encode a dummy text only to recover the embedding dtype and dimension,
        # which the flat memmap on disk does not store.
        test_tokens = tokenizer(['test'],
                                padding=True,
                                truncation=True,
                                max_length=128,
                                return_tensors="pt").to('cuda')
        with torch.no_grad():
            test = model.encoder.get_representation(test_tokens['input_ids'], test_tokens['attention_mask'])
        test = test.cpu().numpy()
        dtype = test.dtype
        dim = test.shape[-1]

        all_embeddings = np.memmap(
            load_embedding,
            mode="r",
            dtype=dtype
        ).reshape(-1, dim)
    else:
        all_embeddings = []
        for start_index in tqdm(range(0, len(corpus), batch_size), desc="Inference Embeddings",
                                disable=len(corpus) < batch_size):
            passages_batch = corpus[start_index:start_index + batch_size]
            d_collated = tokenizer(
                passages_batch,
                padding=True,
                truncation=True,
                max_length=max_length,
                return_tensors="pt",
            ).to('cuda')

            with torch.no_grad():
                corpus_embeddings = model.encoder.get_representation(d_collated['input_ids'], d_collated['attention_mask'])

            corpus_embeddings = corpus_embeddings.cpu().numpy()
            all_embeddings.append(corpus_embeddings)

        all_embeddings = np.concatenate(all_embeddings, axis=0)
        dim = all_embeddings.shape[-1]

        if save_embedding:
            logger.info(f"saving embeddings at {save_path}...")
            memmap = np.memmap(
                save_path,
                shape=all_embeddings.shape,
                mode="w+",
                dtype=all_embeddings.dtype
            )

            length = all_embeddings.shape[0]
            # Write in chunks so very large corpora do not blow up memory.
            save_batch_size = 10000
            if length > save_batch_size:
                for i in tqdm(range(0, length, save_batch_size), leave=False, desc="Saving Embeddings"):
                    j = min(i + save_batch_size, length)
                    memmap[i: j] = all_embeddings[i: j]
            else:
                memmap[:] = all_embeddings

    faiss_index = faiss.index_factory(dim, index_factory, faiss.METRIC_INNER_PRODUCT)

    # Move the index to GPU 0 for faster training and search.
    co = faiss.GpuClonerOptions()
    faiss_index = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), 0, faiss_index, co)

    logger.info("Adding embeddings...")
    all_embeddings = all_embeddings.astype(np.float32)

    faiss_index.train(all_embeddings)
    faiss_index.add(all_embeddings)
    return faiss_index

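# Note on `index_factory`: with the default "Flat" string, faiss builds an exact
# inner-product index, so the `faiss_index.train` call above is a no-op; an
# approximate factory string such as "IVF4096,Flat" (named purely as an
# illustration) would make the training step meaningful. Inner-product search
# behaves like cosine similarity only when embeddings are L2-normalized, which
# depends on how SharedBiEncoder.get_representation is implemented.
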
def search(model: SharedBiEncoder, tokenizer: AutoTokenizer, questions, faiss_index: faiss.Index,
           k: int = 100, batch_size: int = 256, max_length: int = 128):
    """
    1. Encode queries into dense embeddings;
    2. Search through the faiss index.
    """
    q_embeddings = []

    for start_index in tqdm(range(0, len(questions), batch_size), desc="Inference Embeddings",
                            disable=len(questions) < batch_size):
        q_collated = tokenizer(
            questions[start_index: start_index + batch_size],
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        ).to('cuda')

        with torch.no_grad():
            query_embeddings = model.encoder.get_representation(q_collated['input_ids'], q_collated['attention_mask'])
        query_embeddings = query_embeddings.cpu().numpy()
        q_embeddings.append(query_embeddings)

    q_embeddings = np.concatenate(q_embeddings, axis=0)
    query_size = q_embeddings.shape[0]
    all_scores = []
    all_indices = []

    for i in tqdm(range(0, query_size, batch_size), desc="Searching"):
        j = min(i + batch_size, query_size)
        q_embedding = q_embeddings[i: j]
        score, indice = faiss_index.search(q_embedding.astype(np.float32), k=k)
        all_scores.append(score)
        all_indices.append(indice)

    all_scores = np.concatenate(all_scores, axis=0)
    all_indices = np.concatenate(all_indices, axis=0)
    return all_scores, all_indices

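# `search` returns two (num_questions, k) arrays: inner-product scores and
# corpus row indices, with -1 used by faiss as padding when fewer than k
# neighbors exist.
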
def rerank(reranker: RerankerForInference, tokenizer: AutoTokenizer, questions, corpus,
           retrieved_ids, batch_size=128, max_length=256, top_k=30):
    eos = tokenizer.eos_token

    # Build one "question <eos><eos> passage" pair per (question, candidate).
    texts = []
    for idx in range(len(questions)):
        for j in range(top_k):
            texts.append(questions[idx] + eos + eos + corpus[retrieved_ids[idx][j]])
    reranked_ids = []
    rerank_scores = []

    for start_index in tqdm(range(0, len(questions), batch_size), desc="Rerank",
                            disable=len(questions) < batch_size):
        batch_retrieved_ids = retrieved_ids[start_index: start_index + batch_size]
        collated = tokenizer(
            texts[start_index * top_k: (start_index + batch_size) * top_k],
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        ).to('cuda')
        with torch.no_grad():
            reranked_scores = reranker(collated).logits
        # One score per candidate; regroup into (batch, top_k) and sort each
        # question's candidates by descending reranker score.
        reranked_scores = reranked_scores.view(-1, top_k).to('cpu').tolist()
        for m in range(len(reranked_scores)):
            tuple_lst = [(batch_retrieved_ids[m][n], reranked_scores[m][n]) for n in range(top_k)]
            tuple_lst.sort(key=lambda tup: tup[1], reverse=True)
            reranked_ids.append([tup[0] for tup in tuple_lst])
            rerank_scores.append([tup[1] for tup in tuple_lst])

    return reranked_ids, rerank_scores

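# The reranker scores each "question <eos><eos> passage" string as a single
# sequence-classification input. Note the implicit assumption that every entry
# of `retrieved_ids` holds at least `top_k` candidates; a shorter list would
# raise an IndexError when the pair texts are built.
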
def evaluate(preds, labels, cutoffs=[1, 5, 10, 30, 100]):
    """
    Evaluate MRR and Recall at the given cutoffs.
    """
    metrics = {}

    # MRR: reciprocal rank of the first relevant prediction, per cutoff.
    mrrs = np.zeros(len(cutoffs))
    for pred, label in zip(preds, labels):
        jump = False
        for i, x in enumerate(pred, 1):
            if x in label:
                for k, cutoff in enumerate(cutoffs):
                    if i <= cutoff:
                        mrrs[k] += 1 / i
                jump = True
            if jump:
                break
    mrrs /= len(preds)
    for i, cutoff in enumerate(cutoffs):
        metrics[f"MRR@{cutoff}"] = mrrs[i]

    # Recall: fraction of relevant labels retrieved within each cutoff.
    recalls = np.zeros(len(cutoffs))
    for pred, label in zip(preds, labels):
        for k, cutoff in enumerate(cutoffs):
            recall = np.intersect1d(label, pred[:cutoff])
            recalls[k] += len(recall) / len(label)
    recalls /= len(preds)
    for i, cutoff in enumerate(cutoffs):
        metrics[f"Recall@{cutoff}"] = recalls[i]

    return metrics

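# Worked example: preds = [["a", "b"]], labels = [["b"]] gives, at cutoff 5,
# MRR@5 = 1/2 (first relevant item at rank 2) and Recall@5 = 1.0 (the single
# relevant label appears within the top 5).
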
def calculate_score(ground_ids, retrieved_list):
    all_count = 0
    hit_count = 0
    for i in range(len(ground_ids)):
        all_check = True
        hit_check = False
        retrieved_ids = retrieved_list[i]
        ans_ids = ground_ids[i]
        # Each question may have several relevant articles, and each article is
        # a list of acceptable corpus ids. "Hit" needs at least one article
        # covered; "All" needs every article covered.
        for a_ids in ans_ids:
            com = [a_id for a_id in a_ids if a_id in retrieved_ids]
            if len(com) > 0:
                hit_check = True
            else:
                all_check = False

        if hit_check:
            hit_count += 1
        if all_check:
            all_count += 1

    all_acc = all_count / len(ground_ids)
    hit_acc = hit_count / len(ground_ids)
    return hit_acc, all_acc

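# Example: ground_ids[i] = [[3, 4], [7]] means question i has two relevant
# articles (corpus ids 3 or 4 both count as the first one). retrieved = [3, 9]
# scores a "hit" (the first article is covered) but not "all" (7 is missing).
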
def check(ground_ids, retrieved_list, cutoffs=[1, 5, 10, 30, 100]):
    metrics = {}
    for cutoff in cutoffs:
        retrieved_k = [x[:cutoff] for x in retrieved_list]
        hit_acc, all_acc = calculate_score(ground_ids, retrieved_k)
        metrics[f"All@{cutoff}"] = all_acc
        metrics[f"Hit@{cutoff}"] = hit_acc
    return metrics

def save_bi_data(tokenized_queries, ground_ids, indices, scores, file, org_questions=None):
    rst = []

    for i in range(len(tokenized_queries)):
        scores_i = scores[i]
        indices_i = indices[i]
        ans_ids = ground_ids[i]
        all_ans_id = [element for x in ans_ids for element in x]
        # Everything retrieved that is not a known answer becomes a hard negative.
        neg_doc_ids = []
        neg_scores = []
        for count in range(len(indices_i)):
            if indices_i[count] not in all_ans_id and indices_i[count] != -1:
                neg_doc_ids.append(indices_i[count])
                neg_scores.append(scores_i[count])

        for j in range(len(ans_ids)):
            ans_id = ans_ids[j]
            item = {}
            if org_questions is not None:
                item['question'] = org_questions[i]
            item['query'] = tokenized_queries[i]
            item['positives'] = {}
            item['negatives'] = {}
            item['positives']['doc_id'] = []
            item['positives']['score'] = []
            item['negatives']['doc_id'] = neg_doc_ids
            item['negatives']['score'] = neg_scores
            for pos_id in ans_id:
                item['positives']['doc_id'].append(pos_id)
                try:
                    idx = indices_i.index(pos_id)
                    item['positives']['score'].append(scores_i[idx])
                except ValueError:
                    # The positive was not retrieved at all: fall back to the
                    # lowest retrieved score as a pessimistic estimate.
                    item['positives']['score'].append(scores_i[-1])

            rst.append(item)

    with open(f'{file}.jsonl', 'w') as jsonl_file:
        for item in rst:
            json_line = json.dumps(item, ensure_ascii=False)
            jsonl_file.write(json_line + '\n')

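# Each output line of `save_bi_data` is a JSON object shaped like:
#   {"question": ..., "query": ...,
#    "positives": {"doc_id": [...], "score": [...]},
#    "negatives": {"doc_id": [...], "score": [...]}}
# ("question" is only present when org_questions is passed.)
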
def main():
    parser = HfArgumentParser([Args])
    args: Args = parser.parse_args_into_dataclasses()[0]
    print(args)
    model = SharedBiEncoder(model_checkpoint=args.encoder,
                            representation=args.sentence_pooling_method,
                            fixed=True)
    model.to('cuda')
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer if args.tokenizer else args.encoder)
    reranker = RerankerForInference(model_checkpoint=args.cross_checkpoint)
    reranker.to('cuda')
    reranker.eval()
    reranker_tokenizer = AutoTokenizer.from_pretrained(args.cross_tokenizer if args.cross_tokenizer else args.cross_checkpoint)

    # Load the evaluation data: an explicit --data_file wins; otherwise fall
    # back to the split selected by --data_type under --data_path.
    csv_file = True
    if args.data_file:
        if args.data_file.endswith("jsonl"):
            test_data = []
            with open(args.data_file, 'r') as jsonl_file:
                for line in jsonl_file:
                    test_data.append(json.loads(line))
            csv_file = False
        elif args.data_file.endswith("json"):
            csv_file = False
            with open(args.data_file, 'r') as json_file:
                test_data = json.load(json_file)
        elif args.data_file.endswith("csv"):
            test_data = pd.read_csv(args.data_file)
    elif args.data_type == 'eval':
        test_data = pd.read_csv(args.data_path + "/tval.csv")
    elif args.data_type == 'train':
        test_data = pd.read_csv(args.data_path + "/ttrain.csv")
    elif args.data_type == 'all':
        data1 = pd.read_csv(args.data_path + "/ttrain.csv")
        data2 = pd.read_csv(args.data_path + "/ttest.csv")
        data3 = pd.read_csv(args.data_path + "/tval.csv")
        test_data = pd.concat([data1, data3, data2], ignore_index=True)
    else:
        test_data = pd.read_csv(args.data_path + "/ttest.csv")
    corpus_data = pd.read_csv(args.corpus_file)

    corpus = corpus_data['tokenized_text'].tolist()

    # Extract questions and gold labels. CSV splits are already word-segmented;
    # json/jsonl data is segmented here with pyvi after light preprocessing.
    if csv_file:
        ans_ids = []
        ground_ids = []
        org_questions = test_data['question'].tolist()
        questions = test_data['tokenized_question'].tolist()
        for i in range(len(test_data)):
            ans_ids.append(json.loads(test_data['best_ans_id'][i]))
            ground_ids.append(json.loads(test_data['ans_id'][i]))
        ground_truths = []
        for sample in ans_ids:
            temp = [corpus_data['law_id'][y] + "_" + str(corpus_data['article_id'][y]) for y in sample]
            ground_truths.append(temp)
    else:
        ground_truths = []
        ground_ids = []
        org_questions = [sample['question'] for sample in test_data]
        questions = [tokenize(preprocess_question(sample['question'], remove_end_phrase=False)) for sample in test_data]
        for sample in test_data:
            # Some dumps name the field 'relevance_articles', others 'relevant_articles'.
            try:
                articles = sample['relevance_articles']
            except KeyError:
                articles = sample['relevant_articles']
            temp = [it['law_id'] + "_" + str(it['article_id']) for it in articles]
            tempp = [it['ans_id'] for it in articles]
            ground_truths.append(temp)
            ground_ids.append(tempp)

    faiss_index = index(
        model=model,
        tokenizer=tokenizer,
        corpus=corpus,
        batch_size=args.batch_size,
        max_length=args.max_passage_length,
        index_factory=args.index_factory,
        save_path=args.save_path,
        save_embedding=args.save_embedding,
        load_embedding=args.load_embedding
    )

    scores, indices = search(
        model=model,
        tokenizer=tokenizer,
        questions=questions,
        faiss_index=faiss_index,
        k=args.k,
        batch_size=args.batch_size,
        max_length=args.max_query_length
    )

    # Map faiss row ids back to "law_id_article_id" keys, dropping -1 padding
    # and de-duplicating articles that were split into several passages.
    retrieval_results, retrieval_ids = [], []
    for indice in indices:
        indice = indice[indice != -1].tolist()
        rst = []
        for x in indice:
            temp = corpus_data['law_id'][x] + "_" + str(corpus_data['article_id'][x])
            if temp not in rst:
                rst.append(temp)
        retrieval_results.append(rst)
        retrieval_ids.append(indice)

    rerank_ids, rerank_scores = rerank(reranker, reranker_tokenizer, questions, corpus, retrieval_ids,
                                       args.cross_batch_size, args.cross_max_length, args.top_k)

    if args.bi_data:
        save_bi_data(questions, ground_ids, rerank_ids, rerank_scores, args.data_type, org_questions)

    # Retrieval quality before and after reranking.
    metrics = check(ground_ids, retrieval_ids)
    print(metrics)
    metrics = evaluate(retrieval_results, ground_truths)
    print(metrics)
    metrics = check(ground_ids, rerank_ids, cutoffs=[1, 5, 10, 30])
    print(metrics)


if __name__ == "__main__":
    main()
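
# Expected data layout (inferred from the column accesses above, not from a
# schema file): the corpus CSV needs `tokenized_text`, `law_id`, and
# `article_id` columns; the CSV splits (ttrain/tval/ttest.csv) need `question`,
# `tokenized_question`, `best_ans_id`, and `ans_id`; json/jsonl samples need
# `question` plus `relevance_articles`/`relevant_articles` entries carrying
# `law_id`, `article_id`, and `ans_id`.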