scripts
- README.md +78 -0
- eval.py +143 -0
- eval_mor.sh +21 -0
- get_emb.py +142 -0
- mor_env.yml +327 -0
- prepare_rerank.py +245 -0
- requirements.txt +31 -0
- run_reasoning.sh +23 -0
- train_planner.sh +7 -0
- train_reranker.sh +13 -0
README.md
ADDED
@@ -0,0 +1,78 @@
# MoR

# Running the Evaluation and Reranking Script

## Installation

To set up the environment, install the dependencies using Conda or pip:

### Using Conda

```bash
conda env create -f mor_env.yml
conda activate your_env_name  # Replace with the environment name set in mor_env.yml
```
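
If you prefer not to edit the `name:` field in `mor_env.yml`, conda can override it at creation time. A minimal sketch (the environment name `mor` is an arbitrary choice, not part of this repo):

```bash
# Create the environment under an explicit name and activate it
conda env create -f mor_env.yml -n mor
conda activate mor
```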

### Using pip

```bash
pip install -r requirements.txt
```

### Downloading checkpoints and embeddings

Before running inference, go to https://drive.google.com/drive/folders/1ldOYiyrIaZ3AVAKAmNeP0ZWfD3DLZu9D?usp=drive_link and:

(1) download "checkpoints" and put it under the directory MoR/Planning/

(2) download "data" and put it under the directory MoR/Reasoning/

(3) download "model_checkpoint" and put it under the directory MoR/Reasoning/text_retrievers/

## Inference

To run the inference script, execute the following command in the terminal:

```bash
bash eval_mor.sh
```

This script automatically processes all three datasets using the pre-trained planning graph generator and the pre-trained reranker.

## Training (Train MoR from Scratch)

### Step 1: Train the planning graph generator

```bash
bash train_planner.sh
```

### Step 2: Run mixed traversal to collect candidates (note: the reasoning stage itself is training-free)

```bash
bash run_reasoning.sh
```

### Step 3: Train the reranker

```bash
bash train_reranker.sh
```

## Generating training data for the Planner

We provide code to generate your own training data for fine-tuning the Planner with different LLMs.

#### If you are using the Azure API

```bash
python script.py --model "model_name" \
    --dataset_name "dataset_name" \
    --azure_api_key "your_azure_key" \
    --azure_endpoint "your_azure_endpoint" \
    --azure_api_version "your_azure_version"
```

#### If you are using the OpenAI API

```bash
python script.py --model "model_name" \
    --dataset_name "dataset_name" \
    --openai_api_key "your_openai_key" \
    --openai_endpoint "your_openai_endpoint"
```
eval.py
ADDED
@@ -0,0 +1,143 @@
import os
import pickle as pkl

import numpy as np
import pandas as pd
import torch
from argparse import ArgumentParser
from tqdm import tqdm

from stark_qa import load_qa, load_skb
from Reasoning.mor4path import MOR4Path
from Planning.model import Planner
from prepare_rerank import prepare_trajectories


# command-line arguments
parser = ArgumentParser()
parser.add_argument("--dataset_name", type=str, default="mag")
# text retriever name
parser.add_argument("--text_retriever_name", type=str, default="bm25")
# contriever for prime, ada for amazon and mag
parser.add_argument("--scorer_name", type=str, default="ada", help="contriever, ada")
# data split
parser.add_argument("--mod", type=str, default="test", help="train, valid, test")
# device
parser.add_argument("--device", type=str, default="cuda", help="Device to run the model (e.g., 'cuda' or 'cpu').")


if __name__ == "__main__":
    args = parser.parse_args()
    dataset_name = args.dataset_name
    scorer_name = args.scorer_name
    text_retriever_name = args.text_retriever_name
    skb = load_skb(dataset_name)
    qa = load_qa(dataset_name, human_generated_eval=False)

    eval_metrics = [
        "mrr",
        "map",
        "rprecision",
        "recall@5",
        "recall@10",
        "recall@20",
        "recall@50",
        "recall@100",
        "hit@1",
        "hit@3",
        "hit@5",
        "hit@10",
        "hit@20",
        "hit@50",
    ]

    mor_path = MOR4Path(dataset_name, text_retriever_name, scorer_name, skb)
    reasoner = Planner(dataset_name)
    outputs = []
    topk = 100
    split_idx = qa.get_idx_split(test_ratio=1.0)
    mod = args.mod
    all_indices = split_idx[mod].tolist()
    eval_csv = pd.DataFrame(columns=["idx", "query_id", "pred_rank"] + eval_metrics)

    count = 0

    # ***** Planning *****
    # If a plan cache exists, load it; otherwise run the planner and cache the results.
    plan_cache_path = f"./cache/{dataset_name}/path/{mod}_20250222.pkl"
    if os.path.exists(plan_cache_path):
        with open(plan_cache_path, 'rb') as f:
            plan_output_list = pkl.load(f)
    else:
        plan_output_list = []
        for idx, i in enumerate(tqdm(all_indices)):
            plan_output = {}
            query, q_id, ans_ids, _ = qa[i]
            rg = reasoner(query)

            plan_output['query'] = query
            plan_output['q_id'] = q_id
            plan_output['ans_ids'] = ans_ids
            plan_output['rg'] = rg
            plan_output_list.append(plan_output)
        # save plan_output_list
        os.makedirs(os.path.dirname(plan_cache_path), exist_ok=True)
        with open(plan_cache_path, 'wb') as f:
            pkl.dump(plan_output_list, f)

    # ***** Reasoning *****
    for idx, i in enumerate(tqdm(all_indices)):
        query = plan_output_list[idx]['query']
        q_id = plan_output_list[idx]['q_id']
        ans_ids = plan_output_list[idx]['ans_ids']
        rg = plan_output_list[idx]['rg']

        output = mor_path(query, q_id, ans_ids, rg, args)

        ans_ids = torch.LongTensor(ans_ids)

        pred_dict = output['pred_dict']
        result = mor_path.evaluate(pred_dict, ans_ids, metrics=eval_metrics)

        result["idx"], result["query_id"] = i, q_id
        result["pred_rank"] = torch.LongTensor(list(pred_dict.keys()))[
            torch.argsort(torch.tensor(list(pred_dict.values())), descending=True)[
                :topk
            ]
        ].tolist()

        eval_csv = pd.concat([eval_csv, pd.DataFrame([result])], ignore_index=True)

        output['q_id'] = q_id
        outputs.append(output)

        count += 1

    # for metric in eval_metrics:
    #     print(
    #         f"{metric}: {np.mean(eval_csv[eval_csv['idx'].isin(all_indices)][metric])}"
    #     )

    print(f"MOR count: {mor_path.mor_count}")

    # Prepare reranker trajectories from the collected outputs and save them.
    bm25 = mor_path.text_retriever
    test_data = prepare_trajectories(dataset_name, bm25, skb, outputs)
    save_path = f"{dataset_name}_{mod}.pkl"
    with open(save_path, 'wb') as f:
        pkl.dump(test_data, f)
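
eval.py can also be run on a single dataset and split outside of eval_mor.sh. A minimal invocation sketch using the arguments defined above (scorer choice follows the comment in the code: contriever for prime, ada for amazon and mag):

```bash
# Evaluate MoR on the MAG test split with the ada scorer
python eval.py --dataset_name mag --scorer_name ada --mod test
```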
eval_mor.sh
ADDED
@@ -0,0 +1,21 @@
#!/bin/bash

datasets=("mag" "amazon" "prime")
# Define the scorer_name mapping using an associative array
declare -A dataset_scorer_map=(
    [mag]="ada"
    [amazon]="ada"
    [prime]="contriever"
)

for dataset in "${datasets[@]}"; do
    # Get the corresponding scorer_name for the dataset
    scorer_name="${dataset_scorer_map[$dataset]}"
    echo "Processing dataset: $dataset with scorer: $scorer_name"
    python eval.py --dataset_name "$dataset" --scorer_name "$scorer_name" --mod "test"

    cd Reranking
    python rerank.py --dataset_name "$dataset"
    cd ..
done
get_emb.py
ADDED
@@ -0,0 +1,142 @@
import os
import os.path as osp
import random
import sys
import argparse
import pandas as pd

import torch
from tqdm import tqdm

from stark_qa.tools.api_lib.openai_emb import get_contriever, get_contriever_embeddings

sys.path.append('.')
from stark_qa import load_skb, load_qa
from stark_qa.tools.api import get_api_embeddings
from stark_qa.tools.local_encoder import get_llm2vec_embeddings, get_gritlm_embeddings
from models.model import get_embeddings


def parse_args():
    parser = argparse.ArgumentParser()

    # Dataset and embedding model selection
    parser.add_argument('--dataset', default='prime', choices=['amazon', 'prime', 'mag'])
    parser.add_argument('--emb_model', default='contriever',
                        choices=[
                            'text-embedding-ada-002',
                            'text-embedding-3-small',
                            'text-embedding-3-large',
                            'voyage-large-2-instruct',
                            'GritLM/GritLM-7B',
                            'McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp',
                            'all-mpnet-base-v2'  # for sentence transformers
                        ])

    # Mode settings
    parser.add_argument('--mode', default='query', choices=['doc', 'query'])

    # Path settings
    parser.add_argument("--data_dir", default="data/", type=str)
    parser.add_argument("--emb_dir", default="emb/", type=str)

    # Text settings
    parser.add_argument('--add_rel', action='store_true', default=False, help='add relation to the text')
    parser.add_argument('--compact', action='store_true', default=False, help='make the text compact when input to the model')

    # Evaluation settings
    parser.add_argument("--human_generated_eval", action="store_true", help="if mode is `query`, generate query embeddings on the human-generated evaluation split")

    # Batch and node settings
    parser.add_argument("--batch_size", default=1024, type=int)

    # encode kwargs (collected via the custom "ENCODE" metavar below)
    parser.add_argument("--n_max_nodes", default=None, type=int, metavar="ENCODE")
    parser.add_argument("--device", default=None, type=str, metavar="ENCODE")
    parser.add_argument("--peft_model_name", default=None, type=str, help="llm2vec peft model", metavar="ENCODE")
    parser.add_argument("--instruction", type=str, help="gritlm/llm2vec instruction", metavar="ENCODE")

    args = parser.parse_args()

    # Create encode_kwargs from the arguments tagged with the custom metavar "ENCODE"
    encode_kwargs = {k: v for k, v in vars(args).items() if v is not None and parser._option_string_actions[f'--{k}'].metavar == "ENCODE"}

    return args, encode_kwargs


if __name__ == '__main__':
    args, encode_kwargs = parse_args()
    args.human_generated_eval = False  # NOTE: hard-coded override of the CLI flag
    mode_suffix = '_human_generated_eval' if args.human_generated_eval and args.mode == 'query' else ''
    mode_suffix += '_no_rel' if not args.add_rel else ''
    mode_suffix += '_no_compact' if not args.compact else ''
    emb_dir = osp.join(args.emb_dir, args.dataset, args.emb_model, f'{args.mode}{mode_suffix}')
    csv_cache = osp.join(args.data_dir, args.dataset, f'{args.mode}{mode_suffix}.csv')

    print(f'Embedding directory: {emb_dir}')
    os.makedirs(emb_dir, exist_ok=True)
    os.makedirs(os.path.dirname(csv_cache), exist_ok=True)

    if args.mode == 'doc':
        skb = load_skb(args.dataset)
        lst = skb.candidate_ids
        emb_path = osp.join(emb_dir, 'candidate_emb_dict.pt')
    if args.mode == 'query':
        qa_dataset = load_qa(args.dataset, human_generated_eval=args.human_generated_eval)
        lst = [qa_dataset[i][1] for i in range(len(qa_dataset))]
        emb_path = osp.join(emb_dir, 'query_emb_dict.pt')
    random.shuffle(lst)

    # Load existing embeddings if they exist
    if osp.exists(emb_path):
        emb_dict = torch.load(emb_path)
        exist_emb_indices = list(emb_dict.keys())
        print(f'Loaded existing embeddings from {emb_path}. Size: {len(emb_dict)}')
    else:
        emb_dict = {}
        exist_emb_indices = []

    # Load the existing document cache if it exists (doc mode only)
    if args.mode == 'doc' and osp.exists(csv_cache):
        df = pd.read_csv(csv_cache)
        cache_dict = dict(zip(df['index'], df['text']))

        # Ensure that the indices in the cache match the expected indices
        assert set(cache_dict.keys()) == set(lst), 'Indices in cache do not match the candidate indices.'

        indices = list(set(lst) - set(exist_emb_indices))
        texts = [cache_dict[idx] for idx in tqdm(indices, desc="Filtering docs for new embeddings")]
    else:
        indices = lst
        texts = [qa_dataset.get_query_by_qid(idx) if args.mode == 'query'
                 else skb.get_doc_info(idx, add_rel=args.add_rel, compact=args.compact) for idx in tqdm(indices, desc="Gathering docs")]
        if args.mode == 'doc':
            df = pd.DataFrame({'index': indices, 'text': texts})
            df.to_csv(csv_cache, index=False)

    print(f'Generating embeddings for {len(texts)} texts...')
    if args.emb_model == 'contriever':
        encoder, tokenizer = get_contriever(dataset_name=args.dataset)
        for i in tqdm(range(0, len(texts), args.batch_size), desc="Generating embeddings"):
            batch_texts = texts[i:i+args.batch_size]
            batch_embs = get_contriever_embeddings(batch_texts, encoder=encoder, tokenizer=tokenizer, device='cuda')
            batch_embs = batch_embs.view(len(batch_texts), -1).cpu()

            batch_indices = indices[i:i+args.batch_size]
            for idx, emb in zip(batch_indices, batch_embs):
                emb_dict[idx] = emb.view(1, -1)
    else:
        for i in tqdm(range(0, len(texts), args.batch_size), desc="Generating embeddings"):
            batch_texts = texts[i:i+args.batch_size]
            batch_embs = get_embeddings(batch_texts, args.emb_model, **encode_kwargs)
            batch_embs = batch_embs.view(len(batch_texts), -1).cpu()

            batch_indices = indices[i:i+args.batch_size]
            for idx, emb in zip(batch_indices, batch_embs):
                emb_dict[idx] = emb.view(1, -1)

    torch.save(emb_dict, emb_path)
    print(f'Saved {len(emb_dict)} embeddings to {emb_path}!')
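
Typical invocations of get_emb.py, based on the argument definitions above (a sketch; both runs write `.pt` embedding dictionaries under `emb/`):

```bash
# Embed all candidate documents of the prime SKB with contriever
python get_emb.py --dataset prime --emb_model contriever --mode doc

# Embed the queries for the same dataset
python get_emb.py --dataset prime --emb_model contriever --mode query
```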
mor_env.yml
ADDED
@@ -0,0 +1,327 @@
name: your_env_name
channels:
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=5.1=1_gnu
  - bzip2=1.0.8=h5eee18b_6
  - ca-certificates=2024.11.26=h06a4308_0
  - debugpy=1.6.7=py311h6a678d5_0
  - decorator=5.1.1=pyhd3eb1b0_0
  - ipykernel=6.29.5=py311h06a4308_0
  - jedi=0.19.2=py311h06a4308_0
  - jupyter_client=8.6.0=py311h06a4308_0
  - jupyter_core=5.7.2=py311h06a4308_0
  - ld_impl_linux-64=2.40=h12ee557_0
  - libffi=3.4.4=h6a678d5_1
  - libgcc-ng=11.2.0=h1234567_1
  - libgomp=11.2.0=h1234567_1
  - libsodium=1.0.18=h7b6447c_0
  - libstdcxx-ng=11.2.0=h1234567_1
  - libuuid=1.41.5=h5eee18b_0
  - ncurses=6.4=h6a678d5_0
  - nest-asyncio=1.6.0=py311h06a4308_0
  - openssl=3.0.15=h5eee18b_0
  - parso=0.8.4=py311h06a4308_0
  - pip=24.2=py311h06a4308_0
  - prompt_toolkit=3.0.43=hd3eb1b0_0
  - ptyprocess=0.7.0=pyhd3eb1b0_2
  - pure_eval=0.2.2=pyhd3eb1b0_0
  - python=3.11.11=he870216_0
  - python-dateutil=2.9.0post0=py311h06a4308_2
  - pyzmq=25.1.2=py311h6a678d5_0
  - readline=8.2=h5eee18b_0
  - setuptools=75.1.0=py311h06a4308_0
  - six=1.16.0=pyhd3eb1b0_1
  - sqlite=3.45.3=h5eee18b_0
  - stack_data=0.2.0=pyhd3eb1b0_0
  - tk=8.6.14=h39e8969_0
  - tornado=6.4.2=py311h5eee18b_0
  - traitlets=5.14.3=py311h06a4308_0
  - typing_extensions=4.11.0=py311h06a4308_0
  - wheel=0.44.0=py311h06a4308_0
  - xz=5.4.6=h5eee18b_1
  - zeromq=4.3.5=h6a678d5_0
  - zlib=1.2.13=h5eee18b_1
  - pip:
      - accelerate==1.1.1
      - aiobotocore==2.15.2
      - aiohappyeyeballs==2.4.4
      - aiohttp==3.11.9
      - aioitertools==0.12.0
      - aiolimiter==1.2.0
      - aiosignal==1.3.1
      - anndata==0.11.1
      - annotated-types==0.7.0
      - anthropic==0.40.0
      - anyascii==0.3.2
      - anyio==4.6.2.post1
      - argon2-cffi==23.1.0
      - argon2-cffi-bindings==21.2.0
      - array-api-compat==1.9.1
      - arrow==1.3.0
      - asttokens==3.0.0
      - async-lru==2.0.4
      - attrs==24.2.0
      - babel==2.16.0
      - backports-tarfile==1.2.0
      - beautifulsoup4==4.12.3
      - biopython==1.84
      - biothings-client==0.3.1
      - bitarray==3.0.0
      - bitsandbytes==0.44.1
      - bleach==6.2.0
      - blinker==1.9.0
      - bm25s==0.2.5
      - botocore==1.35.36
      - bs4==0.0.2
      - cattrs==24.1.2
      - cellxgene-census==1.15.0
      - certifi==2024.8.30
      - cffi==1.17.1
      - charset-normalizer==3.4.0
      - chembl-webresource-client==0.10.9
      - click==8.1.7
      - colbert==0.40
      - colbert-ai==0.2.21
      - comm==0.2.2
      - contourpy==1.3.1
      - contractions==0.1.73
      - cryptography==44.0.0
      - cut-cross-entropy==24.11.4
      - cycler==0.12.1
      - dataclasses==0.6
      - datasets==3.1.0
      - defusedxml==0.7.1
      - dill==0.3.8
      - distro==1.9.0
      - docker-pycreds==0.4.0
      - docstring-parser==0.16
      - docutils==0.21.2
      - easydict==1.13
      - et-xmlfile==2.0.0
      - evaluate==0.4.2
      - executing==2.1.0
      - fastjsonschema==2.21.1
      - filelock==3.13.1
      - flask==3.1.0
      - fonttools==4.55.1
      - fqdn==1.5.1
      - frozenlist==1.5.0
      - fsspec==2024.2.0
      - fuzzywuzzy==0.18.0
      - gdown==5.2.0
      - gget==0.29.0
      - git-python==1.0.3
      - gitdb==4.0.11
      - gitpython==3.1.43
      - greenlet==3.1.1
      - h11==0.14.0
      - h5py==3.12.1
      - hf-transfer==0.1.8
      - httpcore==1.0.7
      - httpx==0.28.0
      - huggingface-hub==0.26.3
      - icalendar==6.1.0
      - idna==3.10
      - importlib-metadata==8.5.0
      - ipython==8.30.0
      - ipywidgets==8.1.5
      - isoduration==20.11.0
      - itsdangerous==2.2.0
      - jaraco-classes==3.4.0
      - jaraco-context==6.0.1
      - jaraco-functools==4.1.0
      - jeepney==0.8.0
      - jinja2==3.1.3
      - jiter==0.8.0
      - jmespath==1.0.1
      - joblib==1.4.2
      - json5==0.10.0
      - jsonpatch==1.33
      - jsonpointer==3.0.0
      - jsonschema==4.23.0
      - jsonschema-specifications==2024.10.1
      - jupyter==1.1.1
      - jupyter-console==6.6.3
      - jupyter-events==0.11.0
      - jupyter-lsp==2.2.5
      - jupyter-server==2.15.0
      - jupyter-server-terminals==0.5.3
      - jupyterlab==4.3.4
      - jupyterlab-pygments==0.3.0
      - jupyterlab-server==2.27.3
      - jupyterlab-widgets==3.0.13
      - keyring==25.5.0
      - kiwisolver==1.4.7
      - langchain==0.3.9
      - langchain-core==0.3.21
      - langchain-text-splitters==0.3.2
      - langdetect==1.0.9
      - langsmith==0.1.147
      - legacy-api-wrap==1.4.1
      - levenshtein==0.26.1
      - lightning-utilities==0.11.9
      - littleutils==0.2.4
      - llvmlite==0.43.0
      - lxml==5.3.0
      - markdown-it-py==3.0.0
      - markupsafe==2.1.5
      - matplotlib==3.9.3
      - matplotlib-inline==0.1.7
      - matplotlib-venn==1.1.1
      - mdurl==0.1.2
      - mistune==3.1.1
      - moleculeace==3.0.0
      - more-itertools==10.5.0
      - mpmath==1.3.0
      - multidict==6.1.0
      - multiprocess==0.70.16
      - mygene==3.2.2
      - mysql-connector-python==9.1.0
      - natsort==8.4.0
      - nbclient==0.10.2
      - nbconvert==7.16.6
      - nbformat==5.10.4
      - networkx==3.2.1
      - nh3==0.2.19
      - ninja==1.11.1.2
      - nltk==3.9.1
      - notebook==7.3.2
      - notebook-shim==0.2.4
      - numba==0.60.0
      - numpy==1.26.4
      - nvidia-cublas-cu12==12.4.5.8
      - nvidia-cuda-cupti-cu12==12.4.127
      - nvidia-cuda-nvrtc-cu12==12.4.127
      - nvidia-cuda-runtime-cu12==12.4.127
      - nvidia-cudnn-cu12==9.1.0.70
      - nvidia-cufft-cu12==11.2.1.3
      - nvidia-curand-cu12==10.3.5.147
      - nvidia-cusolver-cu12==11.6.1.9
      - nvidia-cusparse-cu12==12.3.1.170
      - nvidia-nccl-cu12==2.21.5
      - nvidia-nvjitlink-cu12==12.4.127
      - nvidia-nvtx-cu12==12.4.127
      - ogb==1.3.6
      - openai==1.56.1
      - openpyxl==3.1.5
      - orjson==3.10.12
      - outdated==0.2.2
      - overrides==7.7.0
      - packaging==24.2
      - pandas==2.2.3
      - pandocfilters==1.5.1
      - patsy==1.0.1
      - peft==0.13.2
      - pexpect==4.9.0
      - pillow==10.2.0
      - pkginfo==1.12.0
      - platformdirs==4.3.6
      - prometheus-client==0.21.1
      - prompt-toolkit==3.0.48
      - propcache==0.2.1
      - protobuf==3.20.3
      - psutil==6.1.0
      - pure-eval==0.2.3
      - pyahocorasick==2.1.0
      - pyaml==24.9.0
      - pyarrow==18.1.0
      - pyarrow-hotfix==0.6
      - pycparser==2.22
      - pydantic==2.10.3
      - pydantic-core==2.27.1
      - pygments==2.18.0
      - pynndescent==0.5.13
      - pyparsing==3.2.0
      - pysocks==1.7.1
      - pytdc==1.1.1
      - python-dotenv==1.0.1
      - python-json-logger==3.2.1
      - python-levenshtein==0.26.1
      - pytz==2024.2
      - pyyaml==6.0.2
      - rapidfuzz==3.10.1
      - rdkit==2023.9.6
      - rdkit-pypi==2022.9.5
      - readme-renderer==44.0
      - referencing==0.36.2
      - regex==2024.11.6
      - requests==2.32.3
      - requests-cache==1.2.1
      - requests-toolbelt==1.0.0
      - rfc3339-validator==0.1.4
      - rfc3986==2.0.0
      - rfc3986-validator==0.1.1
      - rich==13.9.4
      - rpds-py==0.22.3
      - s3fs==2024.2.0
      - safetensors==0.4.5
      - scanpy==1.10.4
      - scikit-learn==1.2.2
      - scikit-optimize==0.10.2
      - scipy==1.14.1
      - seaborn==0.13.2
      - secretstorage==3.3.3
      - send2trash==1.8.3
      - sentence-transformers==3.3.1
      - sentencepiece==0.2.0
      - sentry-sdk==2.19.0
      - session-info==1.0.0
      - setproctitle==1.3.4
      - shtab==1.7.1
      - smmap==5.0.1
      - sniffio==1.3.1
      - somacore==1.0.11
      - soupsieve==2.6
      - sqlalchemy==2.0.36
      - stack-data==0.6.3
      - statsmodels==0.14.4
      - stdlib-list==0.11.0
      - sympy==1.13.1
      - tenacity==9.0.0
      - terminado==0.18.1
      - textsearch==0.0.24
      - threadpoolctl==3.5.0
      - tiledb==0.29.1
      - tiledbsoma==1.11.4
      - tinycss2==1.4.0
      - tokenizers==0.20.3
      - torch==2.5.1+cu124
      - torch-geometric==2.6.1
      - torch-scatter==2.1.2+pt25cu124
      - torchaudio==2.5.1+cu124
      - torchmetrics==1.6.0
      - torchvision==0.20.1+cu124
      - tqdm==4.67.1
      - transformers==4.46.3
      - triton==3.1.0
      - trl==0.12.1
      - twine==6.0.1
      - typeguard==4.4.1
      - types-python-dateutil==2.9.0.20241206
      - typing-extensions==4.12.2
      - tyro==0.9.2
      - tzdata==2024.2
      - ujson==5.10.0
      - umap-learn==0.5.7
      - unsloth==2025.1.8
      - unsloth-zoo==2025.1.5
      - uri-template==1.3.0
      - url-normalize==1.4.3
      - urllib3==2.2.3
      - voyageai==0.3.2
      - wandb==0.18.7
      - wcwidth==0.2.13
      - webcolors==24.11.1
      - webencodings==0.5.1
      - websocket-client==1.8.0
      - werkzeug==3.1.3
      - widgetsnbextension==4.0.13
      - wrapt==1.17.0
      - xformers==0.0.28.post3
      - xxhash==3.5.0
      - yapf==0.43.0
      - yarl==1.18.3
      - zipp==3.21.0
prefix: /home/yongjia/.conda/envs/g_traversal
prepare_rerank.py
ADDED
@@ -0,0 +1,245 @@
from Reasoning.text_retrievers.contriever import Contriever
from Reasoning.text_retrievers.ada import Ada
from stark_qa import load_qa, load_skb

import pickle as pkl
from tqdm import tqdm
from transformers import BertTokenizer, BertModel

model_name = "bert-base-uncased"

tokenizer = BertTokenizer.from_pretrained(model_name)
encoder = BertModel.from_pretrained(model_name)


def get_bm25_scores(dataset_name, bm25, outputs):
    new_outputs = []
    for i in range(len(outputs)):
        query, q_id, ans_ids = outputs[i]['query'], outputs[i]['q_id'], outputs[i]['ans_ids']
        paths = outputs[i]['paths']
        rg = outputs[i]['rg']

        if dataset_name == 'prime':
            new_path_dict = paths
        else:
            # build a new path dict with the -1 padding removed from each path
            new_path_dict = {}
            for key in paths.keys():
                new_path = [x for x in paths[key] if x != -1]
                new_path_dict[key] = new_path

        # collect all nodes on each path except the first element
        candidates_ids = []
        for key in new_path_dict.keys():
            candidates_ids.extend(new_path_dict[key][1:])
        candidates_ids.extend(ans_ids)
        candidates_ids = list(set(candidates_ids))

        # get the BM25 scores
        bm_score_dict = bm25.score(query, q_id, candidate_ids=candidates_ids)
        outputs[i]['bm_score_dict'] = bm_score_dict

        # replace each -1 in the bm_vector_dict with the BM25 score of the node at that position
        bm_vector_dict = outputs[i]['bm_vector_dict']
        for key in bm_vector_dict.keys():
            if -1 in bm_vector_dict[key]:
                path = new_path_dict[key]
                assert len(path) == len(bm_vector_dict[key])

                bm_vector_dict[key] = [bm_score_dict[path[j]] if x == -1 else x for j, x in enumerate(bm_vector_dict[key])]

        outputs[i]['bm_vector_dict'] = bm_vector_dict

        # fix the length of paths in prime: pad with -1 (or truncate) to exactly max_len
        if dataset_name == 'prime':
            max_len = 3
            new_paths = {}
            for key in paths:
                new_path = paths[key]
                if len(paths[key]) < max_len:
                    new_path = [-1] * (max_len - len(paths[key])) + paths[key]
                elif len(paths[key]) > max_len:
                    new_path = paths[key][-max_len:]
                new_paths[key] = new_path

            # assign the fixed-length paths back
            outputs[i]['paths'] = new_paths

        new_outputs.append(outputs[i])

    return new_outputs


def prepare_score_vector_dict(raw_data):
    # build the score_vector_dict: [bm_score, bm_score, bm_score, ada_score/contriever_score]
    for i in range(len(raw_data)):
        pred_dict = raw_data[i]['pred_dict']
        bm_vector_dict = raw_data[i]['bm_vector_dict']
        # initialize the score_vector_dict
        raw_data[i]['score_vector_dict'] = {}
        # append the value from pred_dict to the end of each BM25 vector
        for key in pred_dict:
            bm_vector = bm_vector_dict[key]
            # the ranking score
            rk_score = pred_dict[key]
            score_vector = bm_vector + [rk_score]
            # fix the length at 4: pad with 0 at the front, or keep the last 4 entries
            if len(score_vector) < 4:
                score_vector = [0] * (4 - len(score_vector)) + score_vector
            elif len(score_vector) > 4:
                score_vector = score_vector[-4:]
            raw_data[i]['score_vector_dict'][key] = score_vector

    return raw_data


def prepare_text_emb_symb_enc(raw_data, skb):
    # add the text embeddings and symbolic encodings to the raw_data
    text2emb_list = []
    text2emb_dict = {}

    # symbolic encoding indexed by the number of non-padding nodes on the path
    symbolic_encode_dict = {
        3: [0, 1, 1],
        2: [2, 0, 1],
        1: [2, 2, 0],
    }

    for i in range(len(raw_data)):
        paths = raw_data[i]['paths']
        preds = raw_data[i]['pred_dict']
        assert len(paths) == len(preds)

        # initialize the text_emb_dict and symb_enc_dict
        raw_data[i]['text_emb_dict'] = {}
        raw_data[i]['symb_enc_dict'] = {}

        for key in paths:
            path = paths[key]
            # build a unique textual description of the path from its node types
            text_path_li = [skb.get_node_type_by_id(node_id) if node_id != -1 else "padding" for node_id in path]
            text_path_str = " ".join(text_path_li)
            if text_path_str not in text2emb_list:
                text2emb_list.append(text_path_str)
                text2emb_dict[text_path_str] = -1

            # assign the text path to the raw_data
            raw_data[i]['text_emb_dict'][key] = text_path_str

            # ***** make the symb_enc_dict *****
            # number of non-padding (non -1) nodes on the path
            num_non_1 = len([p for p in path if p != -1])
            symb_enc = symbolic_encode_dict[num_non_1]
            raw_data[i]['symb_enc_dict'][key] = symb_enc

    # ***** embed every unique text path *****
    for key in text2emb_dict.keys():
        # tokenize the node-type string and mean-pool the last hidden states
        text_enc = tokenizer(key, return_tensors='pt')['input_ids']
        outputs = encoder(text_enc)
        last_hidden_states = outputs.last_hidden_state.mean(dim=1)
        text2emb_dict[key] = last_hidden_states.detach()

    new_data = {'data': raw_data, 'text2emb_dict': text2emb_dict}

    return new_data


def prepare_trajectories(dataset_name, bm25, skb, outputs):
    # get the BM25 scores
    new_outputs = get_bm25_scores(dataset_name, bm25, outputs)  # returns a list
    # prepare the score_vector_dict
    new_outputs = prepare_score_vector_dict(new_outputs)  # returns a list
    # prepare the text embeddings and symbolic encodings
    new_data = prepare_text_emb_symb_enc(new_outputs, skb)  # returns a dict

    return new_data


def get_contriever_scores(dataset_name, mod, skb, path):
    with open(path, 'rb') as f:
        data = pkl.load(f)

    raw_data = data['data']

    qa = load_qa(dataset_name, human_generated_eval=False)

    contriever = Contriever(skb, dataset_name, device='cuda')

    split_idx = qa.get_idx_split(test_ratio=1.0)

    all_indices = split_idx[mod].tolist()
    for idx, i in enumerate(tqdm(all_indices)):
        query, q_id, ans_ids, _ = qa[i]
        assert query == raw_data[idx]['query']
        pred_ids = list(raw_data[idx]['pred_dict'].keys())
        candidates_ids = list(set(pred_ids))
        candidates_ids.extend(ans_ids)

        # get the contriever scores
        contriever_score_dict = contriever.score(query, q_id, candidate_ids=candidates_ids)

        raw_data[idx]['contriever_score_dict'] = contriever_score_dict

    data['data'] = raw_data

    with open(path, 'wb') as f:
        pkl.dump(data, f)


def get_ada_scores(dataset_name, mod, skb, path):
    with open(path, 'rb') as f:
        data = pkl.load(f)

    raw_data = data['data']

    qa = load_qa(dataset_name, human_generated_eval=False)

    ada = Ada(skb, dataset_name, device='cuda')

    split_idx = qa.get_idx_split(test_ratio=1.0)

    all_indices = split_idx[mod].tolist()
    for idx, i in enumerate(tqdm(all_indices)):
        query, q_id, ans_ids, _ = qa[i]
        assert query == raw_data[idx]['query']
        pred_ids = list(raw_data[idx]['pred_dict'].keys())
        candidates_ids = list(set(pred_ids))
        candidates_ids.extend(ans_ids)

        # get the ada scores
        ada_score_dict = ada.score(query, q_id, candidate_ids=candidates_ids)

        raw_data[idx]['ada_score_dict'] = ada_score_dict

    data['data'] = raw_data

    with open(path, 'wb') as f:
        pkl.dump(data, f)


if __name__ == '__main__':
    print("Test prepare_rerank")
requirements.txt
ADDED
@@ -0,0 +1,31 @@
anthropic==0.45.2
beautifulsoup4==4.13.3
bm25s==0.2.5
Colbert==0.40
colbert_ai==0.2.21
contractions==0.1.73
datasets==3.1.0
gdown==5.2.0
gritlm==1.0.2
huggingface_hub==0.26.3
langchain==0.3.18
langdetect==1.0.9
llm2vec==0.2.3
nltk==3.9.1
numpy==2.2.3
ogb==1.3.6
openai==1.63.0
pandas==2.2.3
PyTDC==1.1.1
scikit_learn==1.2.2
sentence_transformers==3.3.1
torch==2.5.1+cu124
torch_geometric==2.6.1
torch_scatter==2.1.2+pt25cu124
torchmetrics==1.6.0
tqdm==4.67.1
transformers==4.46.3
trl==0.12.1
unsloth==2025.1.8
voyageai==0.3.2
wandb==0.19.6
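
Note that the `+cu124` and `+pt25cu124` builds above are not hosted on PyPI, so a plain `pip install -r requirements.txt` may fail to resolve them. One possible workaround (a sketch assuming CUDA 12.4; the wheel index URLs are the standard PyTorch and PyG ones, not part of this repo):

```bash
# Pull CUDA 12.4 torch wheels from the PyTorch index and
# torch_scatter wheels from the matching PyG index
pip install -r requirements.txt \
    --extra-index-url https://download.pytorch.org/whl/cu124 \
    -f https://data.pyg.org/whl/torch-2.5.1+cu124.html
```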
run_reasoning.sh
ADDED
@@ -0,0 +1,23 @@
#!/bin/bash

# Define datasets and splits
datasets=("prime")
mods=("test" "val" "train")

# Define the scorer_name mapping using an associative array
declare -A dataset_scorer_map=(
    [mag]="ada"
    [amazon]="ada"
    [prime]="contriever"
)

# Loop through datasets and splits
for dataset in "${datasets[@]}"; do
    # Get the corresponding scorer_name for the dataset
    scorer_name="${dataset_scorer_map[$dataset]}"

    for mod in "${mods[@]}"; do
        echo "Processing dataset: $dataset with mod: $mod and scorer: $scorer_name"
        python eval.py --dataset_name "$dataset" --scorer_name "$scorer_name" --mod "$mod"
    done
done
train_planner.sh
ADDED
@@ -0,0 +1,7 @@
#!/bin/bash

# Navigate to the Planning directory
cd Planning

# Run the training script
python train_eval.py
train_reranker.sh
ADDED
@@ -0,0 +1,13 @@
#!/bin/bash

# Navigate to the Reranking directory
cd Reranking

# Run the training script for each dataset
# amazon
python train_eval_path_amazon.py
# mag
python train_eval_path_mag.py
# prime
python train_eval_path_prime.py