File size: 8,438 Bytes
f96a150 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 |
from Reasoning.text_retrievers.contriever import Contriever
from Reasoning.text_retrievers.ada import Ada
from stark_qa import load_qa, load_skb
import pickle as pkl
from tqdm import tqdm
from transformers import BertTokenizer, BertModel
# Shared BERT tokenizer and encoder, loaded once at import time. They are used
# by prepare_text_emb_symb_enc to embed node-type path strings.
model_name = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(model_name)
encoder = BertModel.from_pretrained(model_name)
def get_bm25_scores(dataset_name, bm25, outputs):
    """Fill in missing BM25 scores for every query record and return the records.

    For each record in ``outputs`` the nodes that appear on the record's paths
    (excluding each path's first element) are scored with ``bm25.score``. The
    result is stored under ``'bm_score_dict'`` and every ``-1`` placeholder in
    ``'bm_vector_dict'`` is replaced by the BM25 score of the node sitting at
    the same position of the corresponding path. For the ``'prime'`` dataset
    the stored paths are additionally normalised to exactly three entries
    (left-padded with ``-1`` or truncated to the last three).

    Args:
        dataset_name: Dataset identifier; ``'prime'`` gets special handling.
        bm25: Object exposing ``score(query, q_id, candidate_ids=...)`` whose
            result can be indexed by candidate id.
        outputs: List of dicts with at least the keys ``'query'``, ``'q_id'``,
            ``'paths'`` and ``'bm_vector_dict'``. The dicts are updated in place.

    Returns:
        A new list holding the same (mutated) record dicts, in input order.

    Raises:
        AssertionError: If a path and its BM25 vector differ in length.
    """
    max_len = 3  # fixed path length used for the 'prime' dataset
    is_prime = dataset_name == 'prime'
    new_outputs = []

    for record in outputs:
        query, q_id = record['query'], record['q_id']
        paths = record['paths']
        # NOTE(review): each record also carries 'ans_ids', which is not
        # consulted here - confirm whether gold answers should be scored too.

        if is_prime:
            # prime paths are used exactly as stored
            new_path_dict = paths
        else:
            # drop the -1 padding entries from every path
            new_path_dict = {
                key: [x for x in path if x != -1] for key, path in paths.items()
            }

        # collect every node on a path except the path's first element
        unique_candidates = set()
        for path in new_path_dict.values():
            unique_candidates.update(path[1:])
        candidates_ids = list(unique_candidates)

        bm_score_dict = bm25.score(query, q_id, candidate_ids=candidates_ids)
        record['bm_score_dict'] = bm_score_dict

        # replace the -1 placeholders in the BM25 vectors with real scores
        bm_vector_dict = record['bm_vector_dict']
        for key, bm_vector in bm_vector_dict.items():
            if -1 in bm_vector:
                path = new_path_dict[key]
                assert len(path) == len(bm_vector)
                bm_vector_dict[key] = [
                    bm_score_dict[path[j]] if x == -1 else x
                    for j, x in enumerate(bm_vector)
                ]

        # prime paths vary in length; force them to exactly max_len entries
        if is_prime:
            new_paths = {}
            for key, path in paths.items():
                if len(path) < max_len:
                    path = [-1] * (max_len - len(path)) + path
                elif len(path) > max_len:
                    path = path[-max_len:]
                new_paths[key] = path
            record['paths'] = new_paths

        new_outputs.append(record)

    return new_outputs
def prepare_score_vector_dict(raw_data):
    """Attach a fixed-length score vector to every candidate of every record.

    Each vector is the candidate's BM25 vector followed by its ranking score
    (the value stored in ``pred_dict``, e.g. an Ada or Contriever score),
    forced to length four: shorter vectors are left-padded with ``0`` and
    longer ones keep only their last four entries. The vectors are stored
    under ``'score_vector_dict'`` on each record.

    Args:
        raw_data: List of record dicts holding ``'pred_dict'`` and
            ``'bm_vector_dict'``; modified in place.

    Returns:
        The same ``raw_data`` list.
    """
    target_len = 4
    for record in raw_data:
        ranking_scores = record['pred_dict']
        bm_vectors = record['bm_vector_dict']
        vectors = record['score_vector_dict'] = {}
        for cand_id, rk_score in ranking_scores.items():
            # BM25 entries first, ranking score last
            vector = bm_vectors[cand_id] + [rk_score]
            shortfall = target_len - len(vector)
            if shortfall > 0:
                vector = [0] * shortfall + vector
            elif shortfall < 0:
                vector = vector[-target_len:]
            vectors[cand_id] = vector
    return raw_data
def prepare_text_emb_symb_enc(raw_data, skb):
    """Add node-type path text and symbolic encodings to every record.

    For every path of every record the node ids are mapped to their node
    types via ``skb.get_node_type_by_id`` (``-1`` entries become the word
    ``"padding"``) and joined with spaces; that string is stored under
    ``'text_emb_dict'``. A symbolic encoding chosen by the number of non
    ``-1`` nodes on the path is stored under ``'symb_enc_dict'``. Each
    distinct path string is then embedded once with the module-level BERT
    ``tokenizer``/``encoder`` (mean over the token axis of the last hidden
    state).

    Args:
        raw_data: List of record dicts holding ``'paths'`` and ``'pred_dict'``;
            modified in place.
        skb: Knowledge base object providing ``get_node_type_by_id``.

    Returns:
        ``{'data': raw_data, 'text2emb_dict': {path_string: embedding}}``.

    Raises:
        AssertionError: If a record has a different number of paths and
            predictions.
        KeyError: If a path has a number of real nodes other than 1, 2 or 3.
    """
    # encoding looked up by the count of real (non -1) nodes on a path
    symbolic_encode_dict = {
        3: [0, 1, 1],
        2: [2, 0, 1],
        1: [2, 2, 0],
    }
    text2emb_dict = {}

    for record in raw_data:
        paths = record['paths']
        preds = record['pred_dict']
        assert len(paths) == len(preds)

        record['text_emb_dict'] = {}
        record['symb_enc_dict'] = {}

        for key, path in paths.items():
            # node-type sentence describing the path
            text_path_str = " ".join(
                skb.get_node_type_by_id(node_id) if node_id != -1 else "padding"
                for node_id in path
            )
            # register each distinct string once; the -1 placeholder is
            # replaced by the real embedding after all records are processed
            text2emb_dict.setdefault(text_path_str, -1)
            record['text_emb_dict'][key] = text_path_str

            num_real_nodes = sum(1 for node_id in path if node_id != -1)
            record['symb_enc_dict'][key] = symbolic_encode_dict[num_real_nodes]

    # embed every distinct path string exactly once
    for text in text2emb_dict:
        input_ids = tokenizer(text, return_tensors='pt')['input_ids']
        encoded = encoder(input_ids)
        text2emb_dict[text] = encoded.last_hidden_state.mean(dim=1).detach()

    return {'data': raw_data, 'text2emb_dict': text2emb_dict}
def prepare_trajectories(dataset_name, bm25, skb, outputs):
    """Run the full preparation pipeline on raw retrieval outputs.

    Stages, in order: BM25 score completion (``get_bm25_scores``),
    score-vector construction (``prepare_score_vector_dict``) and text /
    symbolic path encoding (``prepare_text_emb_symb_enc``).

    Returns:
        The dict produced by ``prepare_text_emb_symb_enc``.
    """
    scored = get_bm25_scores(dataset_name, bm25, outputs)  # list
    with_vectors = prepare_score_vector_dict(scored)  # list
    return prepare_text_emb_symb_enc(with_vectors, skb)  # dict
def _attach_retriever_scores(retriever_cls, score_key, dataset_name, mod, skb, path):
    """Score each record's predicted candidates and write the file back.

    Loads the pickled dict at ``path``, scores the keys of every record's
    ``'pred_dict'`` with a ``retriever_cls`` instance, stores the result on
    the record under ``score_key`` and overwrites ``path`` with the updated
    dict. Records are matched positionally to the ``mod`` split of the QA
    dataset.
    """
    # NOTE: pickle can execute arbitrary code on load; only use trusted files.
    with open(path, 'rb') as f:
        data = pkl.load(f)
    raw_data = data['data']

    qa = load_qa(dataset_name, human_generated_eval=False)
    retriever = retriever_cls(skb, dataset_name, device='cuda')
    split_idx = qa.get_idx_split(test_ratio=1.0)
    all_indices = split_idx[mod].tolist()

    # tqdm shows progress over the split
    for idx, i in enumerate(tqdm(all_indices)):
        query, q_id, _, _ = qa[i]
        record = raw_data[idx]
        # the stored records must line up with the split order
        assert query == record['query'], (
            f"record {idx} does not match query {i} of split {mod!r}"
        )
        candidates_ids = list(set(record['pred_dict'].keys()))
        record[score_key] = retriever.score(query, q_id, candidate_ids=candidates_ids)

    data['data'] = raw_data
    with open(path, 'wb') as f:
        pkl.dump(data, f)


def get_contriever_scores(dataset_name, mod, skb, path):
    """Add Contriever scores to the pickled records stored at ``path``.

    Every record gains a ``'contriever_score_dict'`` entry holding the
    Contriever scores of its predicted candidates; the file is rewritten in
    place.

    Args:
        dataset_name: Dataset passed to ``load_qa`` and ``Contriever``.
        mod: Key of the split (as returned by ``qa.get_idx_split``) to process.
        skb: Knowledge base handed to the retriever.
        path: Pickle file containing a dict with a ``'data'`` list.
    """
    _attach_retriever_scores(
        Contriever, 'contriever_score_dict', dataset_name, mod, skb, path
    )


def get_ada_scores(dataset_name, mod, skb, path):
    """Add Ada scores to the pickled records stored at ``path``.

    Every record gains an ``'ada_score_dict'`` entry holding the Ada scores
    of its predicted candidates; the file is rewritten in place.

    Args:
        dataset_name: Dataset passed to ``load_qa`` and ``Ada``.
        mod: Key of the split (as returned by ``qa.get_idx_split``) to process.
        skb: Knowledge base handed to the retriever.
        path: Pickle file containing a dict with a ``'data'`` list.
    """
    _attach_retriever_scores(Ada, 'ada_score_dict', dataset_name, mod, skb, path)
if __name__ == '__main__':
    # Placeholder entry point: only prints a marker string.
    print("Test prepare_rerank")