import os
import pickle as pkl
from argparse import ArgumentParser

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

from stark_qa import load_qa, load_skb

from Planning.model import Planner
from Reasoning.mor4path import MOR4Path
from prepare_rerank import prepare_trajectories

# Command-line arguments
parser = ArgumentParser()
parser.add_argument("--dataset_name", type=str, default="mag")
parser.add_argument("--text_retriever_name", type=str, default="bm25")
parser.add_argument("--scorer_name", type=str, default="ada",
                    help="Scorer model: 'contriever' (for prime) or 'ada' (for amazon and mag).")
parser.add_argument("--mod", type=str, default="test", help="Dataset split: train, valid, or test.")
parser.add_argument("--device", type=str, default="cuda",
                    help="Device to run the model (e.g., 'cuda' or 'cpu').")
if __name__ == "__main__":
    args = parser.parse_args()
    dataset_name = args.dataset_name
    scorer_name = args.scorer_name
    text_retriever_name = args.text_retriever_name
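    # Load the semi-structured knowledge base (SKB) and the synthesized QA pairs
    # via the STaRK benchmark loaders.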
    skb = load_skb(dataset_name)
    qa = load_qa(dataset_name, human_generated_eval=False)
    eval_metrics = [
        "mrr",
        "map",
        "rprecision",
        "recall@5",
        "recall@10",
        "recall@20",
        "recall@50",
        "recall@100",
        "hit@1",
        "hit@3",
        "hit@5",
        "hit@10",
        "hit@20",
        "hit@50",
    ]
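    # Build the retriever (MOR4Path) and the planner; the planner maps each query
    # to a plan `rg` that the retriever consumes below.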
    mor_path = MOR4Path(dataset_name, text_retriever_name, scorer_name, skb)
    reasoner = Planner(dataset_name)
    outputs = []
    topk = 100
    split_idx = qa.get_idx_split(test_ratio=1.0)
    mod = args.mod
    all_indices = split_idx[mod].tolist()
    eval_csv = pd.DataFrame(columns=["idx", "query_id", "pred_rank"] + eval_metrics)
    count = 0
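
    # The pipeline runs in two stages: stage 1 (planning) produces a plan per
    # query and caches it to disk; stage 2 (reasoning) runs retrieval along each
    # plan and evaluates the resulting ranking.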
    # ***** Planning *****
    # If a plan cache exists, load it; otherwise run the planner over all queries.
    plan_cache_path = f"./cache/{dataset_name}/path/{mod}_20250222.pkl"
    if os.path.exists(plan_cache_path):
        with open(plan_cache_path, 'rb') as f:
            plan_output_list = pkl.load(f)
    else:
        plan_output_list = []
        for idx, i in enumerate(tqdm(all_indices)):
            plan_output = {}
            query, q_id, ans_ids, _ = qa[i]
            rg = reasoner(query)
            plan_output['query'] = query
            plan_output['q_id'] = q_id
            plan_output['ans_ids'] = ans_ids
            plan_output['rg'] = rg
            plan_output_list.append(plan_output)
        # Cache the plans so later runs can skip the planning stage.
        os.makedirs(os.path.dirname(plan_cache_path), exist_ok=True)
        with open(plan_cache_path, 'wb') as f:
            pkl.dump(plan_output_list, f)
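
    # Each cached entry stores the query text, the query id, the gold answer ids,
    # and the planner output `rg`.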
    # ***** Reasoning *****
    for idx, i in enumerate(tqdm(all_indices)):
        query = plan_output_list[idx]['query']
        q_id = plan_output_list[idx]['q_id']
        ans_ids = plan_output_list[idx]['ans_ids']
        rg = plan_output_list[idx]['rg']
        output = mor_path(query, q_id, ans_ids, rg, args)
        ans_ids = torch.LongTensor(ans_ids)
        pred_dict = output['pred_dict']
        result = mor_path.evaluate(pred_dict, ans_ids, metrics=eval_metrics)
        result["idx"], result["query_id"] = i, q_id
        # Sort candidate ids by score (descending) and keep the top-k as the ranking.
        result["pred_rank"] = torch.LongTensor(list(pred_dict.keys()))[
            torch.argsort(torch.tensor(list(pred_dict.values())), descending=True)[:topk]
        ].tolist()
        eval_csv = pd.concat([eval_csv, pd.DataFrame([result])], ignore_index=True)
        output['q_id'] = q_id
        outputs.append(output)
        count += 1
    # Report the mean of each metric over the evaluated indices.
    for metric in eval_metrics:
        print(
            f"{metric}: {np.mean(eval_csv[eval_csv['idx'].isin(all_indices)][metric])}"
        )
    print(f"MOR count: {mor_path.mor_count}")
    # Prepare trajectories for reranking and save them.
    bm25 = mor_path.text_retriever
    test_data = prepare_trajectories(dataset_name, bm25, skb, outputs)
    save_path = f"{dataset_name}_{mod}.pkl"
    with open(save_path, 'wb') as f:
        pkl.dump(test_data, f)