# viLegal_cross/src/eval_cross.py
import faiss
import torch
import logging
import json
import numpy as np
import pandas as pd
from tqdm import tqdm
from typing import Optional
from dataclasses import dataclass, field
from transformers import HfArgumentParser
from transformers import AutoTokenizer
from bi.model import SharedBiEncoder
from bi.preprocess import preprocess_question
from cross_rerank.model import RerankerForInference
#from src.process import process_query, process_text, concat_str
import itertools
from pyvi.ViTokenizer import tokenize
logger = logging.getLogger(__name__)
@dataclass
class Args:
encoder: str = field(
default="vinai/phobert-base-v2",
metadata={'help': 'The encoder name or path.'}
)
    tokenizer: Optional[str] = field(
        default=None,
        metadata={'help': 'Tokenizer name or path for the bi-encoder; defaults to the encoder.'}
    )
    cross_checkpoint: str = field(
        default="vinai/phobert-base-v2",
        metadata={'help': 'The cross-encoder (reranker) checkpoint name or path.'}
    )
    cross_tokenizer: Optional[str] = field(
        default=None,
        metadata={'help': 'Tokenizer name or path for the reranker; defaults to cross_checkpoint.'}
    )
    sentence_pooling_method: str = field(
        default="cls",
        metadata={'help': 'Sentence pooling method for the bi-encoder (e.g. "cls").'}
    )
fp16: bool = field(
default=False,
metadata={'help': 'Use fp16 in inference?'}
)
max_query_length: int = field(
default=32,
metadata={'help': 'Max query length.'}
)
max_passage_length: int = field(
default=256,
metadata={'help': 'Max passage length.'}
)
cross_max_length: int = field(
default=256,
metadata={'help': 'Max cross length.'}
)
    cross_batch_size: int = field(
        default=32,
        metadata={'help': 'Reranker inference batch size.'}
    )
    batch_size: int = field(
        default=128,
        metadata={'help': 'Bi-encoder inference batch size.'}
    )
index_factory: str = field(
default="Flat",
metadata={'help': 'Faiss index factory.'}
)
k: int = field(
default=1000,
metadata={'help': 'How many neighbors to retrieve?'}
)
top_k: int = field(
default=1000,
metadata={'help': 'How many neighbors to rerank?'}
)
data_path: str = field(
default="/kaggle/input/zalo-data",
metadata={'help': 'Path to zalo data.'}
)
    data_type: str = field(
        default="test",
        metadata={'help': 'Which data split to evaluate (train/eval/test/all).'}
    )
corpus_file: str = field(
default="/kaggle/input/zalo-data",
metadata={'help': 'Path to zalo corpus.'}
)
    data_file: Optional[str] = field(
        default=None,
        metadata={'help': 'Path to the evaluation data file (csv/json/jsonl); overrides data_path/data_type.'}
    )
    bi_data: bool = field(
        default=False,
        metadata={'help': 'Save reranked results as bi-encoder training data (jsonl)?'}
    )
save_embedding: bool = field(
default=False,
metadata={'help': 'Save embeddings in memmap at save_dir?'}
)
load_embedding: str = field(
default='',
metadata={'help': 'Path to saved embeddings.'}
)
save_path: str = field(
default="embeddings.memmap",
metadata={'help': 'Path to save embeddings.'}
)
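# Example invocation (checkpoint names and paths below are placeholders, not part of this repo):
#   python src/eval_cross.py \
#       --encoder vinai/phobert-base-v2 \
#       --cross_checkpoint path/to/reranker-checkpoint \
#       --corpus_file path/to/corpus.csv \
#       --data_path path/to/zalo-data \
#       --data_type test \
#       --k 1000 --top_k 30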
def index(model: SharedBiEncoder, tokenizer: AutoTokenizer, corpus, batch_size: int = 16,
          max_length: int = 512, index_factory: str = "Flat", save_path: str = None,
          save_embedding: bool = False, load_embedding: str = ''):
"""
1. Encode the entire corpus into dense embeddings;
2. Create faiss index;
3. Optionally save embeddings.
"""
    if load_embedding != '':
        # Probe the encoder once to recover the dtype and embedding dimension of the saved memmap.
        test_tokens = tokenizer(['test'],
                                padding=True,
                                truncation=True,
                                max_length=128,
                                return_tensors="pt").to('cuda')
        with torch.no_grad():
            test = model.encoder.get_representation(test_tokens['input_ids'], test_tokens['attention_mask'])
        test = test.cpu().numpy()
        dtype = test.dtype
        dim = test.shape[-1]
all_embeddings = np.memmap(
load_embedding,
mode="r",
dtype=dtype
).reshape(-1, dim)
else:
#df_corpus = pd.DataFrame()
#df_corpus['text'] = corpus
#pandarallel.initialize(progress_bar=True, use_memory_fs=False, nb_workers=12)
#df_corpus['processed_text'] = df_corpus['text'].parallel_apply(process_text)
#processed_corpus = df_corpus['processed_text'].tolist()
#model.to('cuda')
all_embeddings = []
for start_index in tqdm(range(0, len(corpus), batch_size), desc="Inference Embeddings",
disable=len(corpus) < batch_size):
passages_batch = corpus[start_index:start_index + batch_size]
d_collated = tokenizer(
passages_batch,
padding=True,
truncation=True,
max_length=max_length,
return_tensors="pt",
).to('cuda')
with torch.no_grad():
corpus_embeddings = model.encoder.get_representation(d_collated['input_ids'], d_collated['attention_mask'])
corpus_embeddings = corpus_embeddings.cpu().numpy()
all_embeddings.append(corpus_embeddings)
all_embeddings = np.concatenate(all_embeddings, axis=0)
dim = all_embeddings.shape[-1]
if save_embedding:
logger.info(f"saving embeddings at {save_path}...")
memmap = np.memmap(
save_path,
shape=all_embeddings.shape,
mode="w+",
dtype=all_embeddings.dtype
)
length = all_embeddings.shape[0]
# add in batch
save_batch_size = 10000
if length > save_batch_size:
for i in tqdm(range(0, length, save_batch_size), leave=False, desc="Saving Embeddings"):
j = min(i + save_batch_size, length)
memmap[i: j] = all_embeddings[i: j]
else:
memmap[:] = all_embeddings
# create faiss index
faiss_index = faiss.index_factory(dim, index_factory, faiss.METRIC_INNER_PRODUCT)
#if model.device == torch.device("cuda"):
if True:
co = faiss.GpuClonerOptions()
#co = faiss.GpuMultipleClonerOptions()
#co.useFloat16 = True
faiss_index = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), 0, faiss_index, co)
#faiss_index = faiss.index_cpu_to_all_gpus(faiss_index, co)
# NOTE: faiss only accepts float32
logger.info("Adding embeddings...")
all_embeddings = all_embeddings.astype(np.float32)
#print(all_embeddings[0])
faiss_index.train(all_embeddings)
faiss_index.add(all_embeddings)
return faiss_index
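# Note: when save_embedding is enabled, the corpus embeddings are written to save_path as a raw
# numpy memmap with no header. A later run can skip re-encoding by passing --load_embedding with
# that same path; dtype and dimension are then recovered from a single probe encoding, so the
# memmap must have been produced by the same encoder configuration.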
def search(model: SharedBiEncoder, tokenizer:AutoTokenizer, questions, faiss_index: faiss.Index, k:int = 100, batch_size: int = 256, max_length: int=128):
"""
1. Encode queries into dense embeddings;
2. Search through faiss index
"""
#model.to('cuda')
q_embeddings = []
#questions = queries['tokenized_question'].tolist()
#questions = [process_query(x) for x in questions]
for start_index in tqdm(range(0, len(questions), batch_size), desc="Inference Embeddings",
disable=len(questions) < batch_size):
        q_collated = tokenizer(
            questions[start_index: start_index + batch_size],
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        ).to('cuda')
with torch.no_grad():
query_embeddings = model.encoder.get_representation(q_collated['input_ids'], q_collated['attention_mask'])
query_embeddings = query_embeddings.cpu().numpy()
q_embeddings.append(query_embeddings)
q_embeddings = np.concatenate(q_embeddings, axis=0)
query_size = q_embeddings.shape[0]
all_scores = []
all_indices = []
for i in tqdm(range(0, query_size, batch_size), desc="Searching"):
j = min(i + batch_size, query_size)
q_embedding = q_embeddings[i: j]
score, indice = faiss_index.search(q_embedding.astype(np.float32), k=k)
all_scores.append(score)
all_indices.append(indice)
all_scores = np.concatenate(all_scores, axis=0)
all_indices = np.concatenate(all_indices, axis=0)
return all_scores, all_indices
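# The returned scores are inner-product similarities (the index is built with
# faiss.METRIC_INNER_PRODUCT), and the returned indices are row offsets into `corpus`,
# padded with -1 when fewer than k neighbors are found.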
def rerank(reranker: RerankerForInference, tokenizer: AutoTokenizer, questions, corpus, retrieved_ids, batch_size=128, max_length=256, top_k=30):
eos = tokenizer.eos_token
#questions = queries['tokenized_question'].tolist()
texts = []
for idx in range(len(questions)):
for j in range(top_k):
texts.append(questions[idx] + eos + eos + corpus[retrieved_ids[idx][j]])
reranked_ids = []
rerank_scores = []
for start_index in tqdm(range(0, len(questions), batch_size), desc="Rerank",
disable=len(questions) < batch_size):
batch_retrieved_ids = retrieved_ids[start_index: start_index+batch_size]
collated = tokenizer(
texts[start_index*top_k: (start_index + batch_size)*top_k],
padding=True,
truncation=True,
max_length=max_length,
return_tensors="pt",
).to('cuda')
        with torch.no_grad():
            reranked_scores = reranker(collated).logits
        reranked_scores = reranked_scores.view(-1, top_k).to('cpu').tolist()
for m in range(len(reranked_scores)):
tuple_lst = [(batch_retrieved_ids[m][n], reranked_scores[m][n]) for n in range(top_k)]
tuple_lst.sort(key=lambda tup: tup[1], reverse=True)
reranked_ids.append([tup[0] for tup in tuple_lst])
rerank_scores.append([tup[1] for tup in tuple_lst])
return reranked_ids, rerank_scores
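# The reranker scores "query <eos> <eos> passage" strings; for RoBERTa-style checkpoints such as
# PhoBERT the eos token is "</s>". Each query's top_k retrieved passages are re-sorted by the
# cross-encoder logit. A minimal sketch of the pair construction for one query (illustrative only):
#   pairs = [query + eos + eos + corpus[doc_id] for doc_id in retrieved_ids[0][:top_k]]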
def evaluate(preds, labels, cutoffs=[1,5,10,30,100]):
"""
Evaluate MRR and Recall at cutoffs.
"""
metrics = {}
# MRR
mrrs = np.zeros(len(cutoffs))
for pred, label in zip(preds, labels):
jump = False
for i, x in enumerate(pred, 1):
if x in label:
for k, cutoff in enumerate(cutoffs):
if i <= cutoff:
mrrs[k] += 1 / i
jump = True
if jump:
break
mrrs /= len(preds)
for i, cutoff in enumerate(cutoffs):
mrr = mrrs[i]
metrics[f"MRR@{cutoff}"] = mrr
# Recall
recalls = np.zeros(len(cutoffs))
for pred, label in zip(preds, labels):
for k, cutoff in enumerate(cutoffs):
recall = np.intersect1d(label, pred[:cutoff])
recalls[k] += len(recall) / len(label)
recalls /= len(preds)
for i, cutoff in enumerate(cutoffs):
recall = recalls[i]
metrics[f"Recall@{cutoff}"] = recall
return metrics
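# Toy example of the metrics above (not from the dataset): with preds = [["a", "b", "c"]] and
# labels = [["b"]], the first relevant hit is at rank 2, so MRR@1 = 0.0 and MRR@5 = 0.5,
# while Recall@1 = 0.0 and Recall@5 = 1.0.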
def calculate_score(ground_ids, retrieved_list):
all_count = 0
hit_count = 0
for i in range(len(ground_ids)):
all_check = True
hit_check = False
retrieved_ids = retrieved_list[i]
ans_ids = ground_ids[i]
for a_ids in ans_ids:
com = [a_id for a_id in a_ids if a_id in retrieved_ids]
if len(com) > 0:
hit_check = True
else:
all_check = False
if hit_check:
hit_count += 1
if all_check:
all_count += 1
all_acc = all_count/len(ground_ids)
hit_acc = hit_count/len(ground_ids)
return hit_acc, all_acc
def check(ground_ids, retrieved_list, cutoffs=[1,5,10,30,100]):
metrics = {}
for cutoff in cutoffs:
retrieved_k = [x[:cutoff] for x in retrieved_list]
hit_acc, all_acc = calculate_score(ground_ids, retrieved_k)
metrics[f"All@{cutoff}"] = all_acc
metrics[f"Hit@{cutoff}"] = hit_acc
return metrics
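# Hit@k counts a query as correct when at least one answer group has some article retrieved in the
# top k; All@k only when every answer group does. For example (hypothetical ids), with
# ground_ids = [[[1, 2], [3]]] and a top-k list of [1, 7], the query scores a Hit (article 1 covers
# the first group) but not an All (group [3] is missed).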
def save_bi_data(tokenized_queries, ground_ids, indices, scores, file, org_questions=None):
rst = []
#tokenized_queries = test_data['tokenized_question'].tolist()
for i in range(len(tokenized_queries)):
scores_i = scores[i]
indices_i = indices[i]
ans_ids = ground_ids[i]
all_ans_id = [element for x in ans_ids for element in x]
neg_doc_ids = []
neg_scores = []
for count in range(len(indices_i)):
if indices_i[count] not in all_ans_id and indices_i[count] != -1:
neg_doc_ids.append(indices_i[count])
neg_scores.append(scores_i[count])
for j in range(len(ans_ids)):
ans_id = ans_ids[j]
item = {}
            if org_questions is not None:
item['question'] = org_questions[i]
item['query'] = tokenized_queries[i]
item['positives'] = {}
item['negatives'] = {}
item['positives']['doc_id'] = []
item['positives']['score'] = []
item['negatives']['doc_id'] = neg_doc_ids
item['negatives']['score'] = neg_scores
for pos_id in ans_id:
item['positives']['doc_id'].append(pos_id)
                try:
                    idx = indices_i.index(pos_id)
                    item['positives']['score'].append(scores_i[idx])
                except ValueError:
                    # Positive article was not retrieved: fall back to the lowest retrieved score.
                    item['positives']['score'].append(scores_i[-1])
rst.append(item)
with open(f'{file}.jsonl', 'w') as jsonl_file:
for item in rst:
json_line = json.dumps(item, ensure_ascii=False)
jsonl_file.write(json_line + '\n')
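# Each line of the resulting <data_type>.jsonl file is one training record of the form
# (schema inferred from the construction above; "question" is present only when the original
# questions are passed in):
#   {"question": ..., "query": ..., "positives": {"doc_id": [...], "score": [...]},
#    "negatives": {"doc_id": [...], "score": [...]}}
# where negatives are retrieved documents that do not appear in any answer group.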
def main():
parser = HfArgumentParser([Args])
args: Args = parser.parse_args_into_dataclasses()[0]
print(args)
model = SharedBiEncoder(model_checkpoint=args.encoder,
representation=args.sentence_pooling_method,
fixed=True)
model.to('cuda')
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer if args.tokenizer else args.encoder)
reranker = RerankerForInference(model_checkpoint=args.cross_checkpoint)
reranker.to('cuda')
reranker_tokenizer = AutoTokenizer.from_pretrained(args.cross_tokenizer if args.cross_tokenizer else args.cross_checkpoint)
csv_file = True
if args.data_file:
if args.data_file.endswith("jsonl"):
test_data = []
with open(args.data_file, 'r') as jsonl_file:
for line in jsonl_file:
temp = json.loads(line)
test_data.append(temp)
csv_file=False
elif args.data_file.endswith("json"):
csv_file=False
with open(args.data_file, 'r') as json_file:
test_data = json.load(json_file)
elif args.data_file.endswith("csv"):
test_data = pd.read_csv(args.data_file)
elif args.data_type == 'eval':
test_data = pd.read_csv(args.data_path + "/tval.csv")
elif args.data_type == 'train':
test_data = pd.read_csv(args.data_path + "/ttrain.csv")
elif args.data_type == 'all':
data1 = pd.read_csv(args.data_path + "/ttrain.csv")
data2 = pd.read_csv(args.data_path + "/ttest.csv")
data3 = pd.read_csv(args.data_path + "/tval.csv")
test_data = pd.concat([data1, data3, data2], ignore_index=True)
else:
test_data = pd.read_csv(args.data_path + "/ttest.csv")
corpus_data = pd.read_csv(args.corpus_file)
#dcorpus = pd.DataFrame(corpus_data)
#pandarallel.initialize(progress_bar=True, use_memory_fs=False, nb_workers=12)
#dcorpus["full_text"] = dcorpus.parallel_apply(concat_str, axis=1)
corpus = corpus_data['tokenized_text'].tolist()
if csv_file:
ans_ids = []
ground_ids = []
org_questions = test_data['question'].tolist()
questions = test_data['tokenized_question'].tolist()
for i in range(len(test_data)):
ans_ids.append(json.loads(test_data['best_ans_id'][i]))
ground_ids.append(json.loads(test_data['ans_id'][i]))
ground_truths = []
for sample in ans_ids:
temp = [corpus_data['law_id'][y] + "_" + str(corpus_data['article_id'][y]) for y in sample]
ground_truths.append(temp)
else:
ground_truths = []
ground_ids = []
org_questions = [sample['question'] for sample in test_data]
questions = [tokenize(preprocess_question(sample['question'], remove_end_phrase=False)) for sample in test_data]
for sample in test_data:
            try:
                # Some files name the field "relevance_articles", others "relevant_articles".
                temp = [it['law_id'] + "_" + it['article_id'] for it in sample['relevance_articles']]
                tempp = [it['ans_id'] for it in sample['relevance_articles']]
            except KeyError:
                temp = [it['law_id'] + "_" + it['article_id'] for it in sample['relevant_articles']]
                tempp = [it['ans_id'] for it in sample['relevant_articles']]
ground_truths.append(temp)
ground_ids.append(tempp)
faiss_index = index(
model=model,
tokenizer=tokenizer,
corpus=corpus,
batch_size=args.batch_size,
max_length=args.max_passage_length,
index_factory=args.index_factory,
save_path=args.save_path,
save_embedding=args.save_embedding,
load_embedding=args.load_embedding
)
scores, indices = search(
model=model,
tokenizer=tokenizer,
questions=questions,
faiss_index=faiss_index,
k=args.k,
batch_size=args.batch_size,
max_length=args.max_query_length
)
retrieval_results, retrieval_ids = [], []
for indice in indices:
# filter invalid indices
indice = indice[indice != -1].tolist()
rst = []
for x in indice:
temp = corpus_data['law_id'][x] + "_" + str(corpus_data['article_id'][x])
if temp not in rst:
rst.append(temp)
retrieval_results.append(rst)
retrieval_ids.append(indice)
rerank_ids, rerank_scores = rerank(reranker, reranker_tokenizer, questions, corpus, retrieval_ids, args.cross_batch_size, args.cross_max_length, args.top_k)
if args.bi_data:
save_bi_data(questions, ground_ids, rerank_ids, rerank_scores, args.data_type, org_questions)
    # Hit/All accuracy of the raw retrieval (by corpus row id).
    metrics = check(ground_ids, retrieval_ids)
    print(metrics)
    # MRR/Recall of the raw retrieval (by law_id + article_id).
    metrics = evaluate(retrieval_results, ground_truths)
    print(metrics)
    # Hit/All accuracy after cross-encoder reranking.
    metrics = check(ground_ids, rerank_ids, cutoffs=[1, 5, 10, 30])
    print(metrics)
if __name__ == "__main__":
main()