import os
import pickle as pkl
from argparse import ArgumentParser

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

from stark_qa import load_qa, load_skb

from Reasoning.mor4path import MOR4Path
from Planning.model import Planner
from prepare_rerank import prepare_trajectories


# Command-line arguments
parser = ArgumentParser()
parser.add_argument("--dataset_name", type=str, default="mag", help="STaRK dataset name (e.g., 'mag', 'amazon', 'prime').")
parser.add_argument("--text_retriever_name", type=str, default="bm25", help="Text retriever used for lexical retrieval.")
parser.add_argument("--scorer_name", type=str, default="ada", help="Embedding scorer: 'contriever' for prime, 'ada' for amazon and mag.")
parser.add_argument("--mod", type=str, default="test", help="Dataset split to evaluate: 'train', 'valid', or 'test'.")
parser.add_argument("--device", type=str, default="cuda", help="Device to run the model on (e.g., 'cuda' or 'cpu').")




if __name__ == "__main__":

    args = parser.parse_args()
    dataset_name = args.dataset_name
    scorer_name = args.scorer_name
    text_retriever_name = args.text_retriever_name

    # Load the semi-structured knowledge base (SKB) and the synthesized QA set.
    skb = load_skb(dataset_name)
    qa = load_qa(dataset_name, human_generated_eval=False)
    
    eval_metrics = [
        "mrr",
        "map",
        "rprecision",
        "recall@5",
        "recall@10",
        "recall@20",
        "recall@50",
        "recall@100",
        "hit@1",
        "hit@3",
        "hit@5",
        "hit@10",
        "hit@20",
        "hit@50",
    ]
    
    # Instantiate the path-based retriever and the query planner.
    mor_path = MOR4Path(dataset_name, text_retriever_name, scorer_name, skb)
    reasoner = Planner(dataset_name)

    outputs = []
    topk = 100
    split_idx = qa.get_idx_split(test_ratio=1.0)
    mod = args.mod
    all_indices = split_idx[mod].tolist()
    eval_csv = pd.DataFrame(columns=["idx", "query_id", "pred_rank"] + eval_metrics)

    count = 0
    
    # ***** Planning *****
    # Reuse cached planner outputs if they exist; otherwise run the planner on
    # every query in the split and cache the results.
    plan_cache_path = f"./cache/{dataset_name}/path/{mod}_20250222.pkl"
    if os.path.exists(plan_cache_path):
        with open(plan_cache_path, 'rb') as f:
            plan_output_list = pkl.load(f)
    else:
        plan_output_list = []
        for i in tqdm(all_indices):
            query, q_id, ans_ids, _ = qa[i]
            rg = reasoner(query)  # planner output for this query

            plan_output = {
                'query': query,
                'q_id': q_id,
                'ans_ids': ans_ids,
                'rg': rg,
            }
            plan_output_list.append(plan_output)

        # Save plan_output_list to the cache.
        os.makedirs(os.path.dirname(plan_cache_path), exist_ok=True)
        with open(plan_cache_path, 'wb') as f:
            pkl.dump(plan_output_list, f)
    
    
    # ***** Reasoning *****
    for idx, i in enumerate(tqdm(all_indices)):

        query = plan_output_list[idx]['query']
        q_id = plan_output_list[idx]['q_id']
        ans_ids = plan_output_list[idx]['ans_ids']
        rg = plan_output_list[idx]['rg']

        # Run path-based retrieval for this query given the planner output.
        output = mor_path(query, q_id, ans_ids, rg, args)

        ans_ids = torch.LongTensor(ans_ids)

        pred_dict = output['pred_dict']
        result = mor_path.evaluate(pred_dict, ans_ids, metrics=eval_metrics)

        result["idx"], result["query_id"] = i, q_id
        # Candidate ids sorted by predicted score (descending), truncated to top-k.
        result["pred_rank"] = torch.LongTensor(list(pred_dict.keys()))[
            torch.argsort(torch.tensor(list(pred_dict.values())), descending=True)[
                :topk
            ]
        ].tolist()

        eval_csv = pd.concat([eval_csv, pd.DataFrame([result])], ignore_index=True)

        output['q_id'] = q_id
        outputs.append(output)

        count += 1

        # Uncomment to print running metrics over the processed queries:
        # for metric in eval_metrics:
        #     print(
        #         f"{metric}: {np.mean(eval_csv[eval_csv['idx'].isin(all_indices)][metric])}"
        #     )
    
    
    print(f"MOR count: {mor_path.mor_count}")
    

    # prepare trajectories and save
    bm25 = mor_path.text_retriever
    test_data = prepare_trajectories(dataset_name, bm25, skb, outputs)
    save_path = f"{dataset_name}_{mod}.pkl"
    with open(save_path, 'wb') as f:
        pkl.dump(test_data, f)
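
# A minimal sketch of consuming the saved trajectories downstream (the exact
# structure of `test_data` is whatever prepare_trajectories returns; this only
# assumes the file round-trips through pickle):
#
#   import pickle as pkl
#   with open("mag_test.pkl", "rb") as f:
#       test_data = pkl.load(f)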