GagaLey committed on
Commit
f96a150
1 Parent(s): f677f59
Files changed (10)
  1. README.md +78 -0
  2. eval.py +143 -0
  3. eval_mor.sh +21 -0
  4. get_emb.py +142 -0
  5. mor_env.yml +327 -0
  6. prepare_rerank.py +245 -0
  7. requirements.txt +31 -0
  8. run_reasoning.sh +23 -0
  9. train_planner.sh +7 -0
  10. train_reranker.sh +13 -0
README.md ADDED
@@ -0,0 +1,78 @@
+ # MoR
+
+ # Running the Evaluation and Reranking Script
+
+ ## Installation
+ To set up the environment, install the dependencies with either Conda or pip:
+
+ ### Using Conda
+ ```bash
+ conda env create -f mor_env.yml
+ conda activate your_env_name # Replace with the actual environment name
+ ```
+
+ ### Using pip
+ ```bash
+ pip install -r requirements.txt
+ ```
+
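+ As a quick optional sanity check, you can confirm that the core dependencies import cleanly. Note that `eval.py` also imports `stark_qa`, which is not pinned in `requirements.txt`; install it separately if it is missing (e.g. via `pip install stark-qa`):
+
+ ```bash
+ python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
+ python -c "import transformers, stark_qa; print('dependencies OK')"
+ ```
+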
+ ### Checkpoints and embeddings download
+ Before running inference, please go to https://drive.google.com/drive/folders/1ldOYiyrIaZ3AVAKAmNeP0ZWfD3DLZu9D?usp=drive_link and:
+
+ (1) download "checkpoints" and put it under the directory MoR/Planning/
+
+ (2) download "data" and put it under the directory MoR/Reasoning/
+
+ (3) download "model_checkpoint" and put it under the directory MoR/Reasoning/text_retrievers/
+
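+ If you prefer the command line, the same folder can be fetched with `gdown` (already included in the requirements). A minimal sketch, assuming the folder names on Drive match those above:
+
+ ```bash
+ gdown --folder "https://drive.google.com/drive/folders/1ldOYiyrIaZ3AVAKAmNeP0ZWfD3DLZu9D" -O drive_download
+ mv drive_download/checkpoints MoR/Planning/
+ mv drive_download/data MoR/Reasoning/
+ mv drive_download/model_checkpoint MoR/Reasoning/text_retrievers/
+ ```
+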
+ ## Inference
+ To run the inference script, execute the following command in the terminal:
+
+ ```bash
+ bash eval_mor.sh
+ ```
+
+ This script automatically processes all three datasets using the pre-trained planning graph generator and the pre-trained reranker.
+
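+ To evaluate a single dataset instead, call `eval.py` directly. As in `eval_mor.sh`, `mag` and `amazon` use the `ada` scorer, while `prime` uses `contriever`:
+
+ ```bash
+ python eval.py --dataset_name mag --scorer_name ada --mod test
+ ```
+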
+ ## Training (Train MoR from Scratch)
+ ### Step 1: Train the planning graph generator
+
+ ```bash
+ bash train_planner.sh
+ ```
+
+ ### Step 2: Run mixed traversal to collect candidates (note: there is no training process for reasoning)
+
+ ```bash
+ bash run_reasoning.sh
+ ```
+
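+ Note that `run_reasoning.sh` lists only `prime` in its `datasets` array by default; to collect candidates for all three datasets, edit the array at the top of the script:
+
+ ```bash
+ datasets=("mag" "amazon" "prime")
+ ```
+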
+ ### Step 3: Train the reranker
+
+ ```bash
+ bash train_reranker.sh
+ ```
+
+ ## Generating training data for the Planner
+ We provide code to generate your own training data for fine-tuning the Planner with different LLMs.
+
+ #### If you are using the Azure API
+
+ ```bash
+ python script.py --model "model_name" \
+   --dataset_name "dataset_name" \
+   --azure_api_key "your_azure_key" \
+   --azure_endpoint "your_azure_endpoint" \
+   --azure_api_version "your_azure_version"
+ ```
+
+ #### If you are using the OpenAI API
+
+ ```bash
+ python script.py --model "model_name" \
+   --dataset_name "dataset_name" \
+   --openai_api_key "your_openai_key" \
+   --openai_endpoint "your_openai_endpoint"
+ ```
eval.py ADDED
@@ -0,0 +1,143 @@
+ import os
+ import pickle as pkl
+ from argparse import ArgumentParser
+
+ import numpy as np
+ import pandas as pd
+ import torch
+ from tqdm import tqdm
+
+ from stark_qa import load_qa, load_skb
+
+ from Planning.model import Planner
+ from Reasoning.mor4path import MOR4Path
+ from prepare_rerank import prepare_trajectories
+
+
+ # Command-line arguments
+ parser = ArgumentParser()
+ parser.add_argument("--dataset_name", type=str, default="mag")
+ # text retriever name
+ parser.add_argument("--text_retriever_name", type=str, default="bm25")
+ parser.add_argument("--scorer_name", type=str, default="ada", help="contriever, ada")  # contriever for prime, ada for amazon and mag
+ # data split to evaluate
+ parser.add_argument("--mod", type=str, default="test", help="train, val, test")
+ # device
+ parser.add_argument("--device", type=str, default="cuda", help="Device to run the model (e.g., 'cuda' or 'cpu').")
+
+
+ if __name__ == "__main__":
+     args = parser.parse_args()
+     dataset_name = args.dataset_name
+     scorer_name = args.scorer_name
+     text_retriever_name = args.text_retriever_name
+     skb = load_skb(dataset_name)
+     qa = load_qa(dataset_name, human_generated_eval=False)
+
+     eval_metrics = [
+         "mrr",
+         "map",
+         "rprecision",
+         "recall@5",
+         "recall@10",
+         "recall@20",
+         "recall@50",
+         "recall@100",
+         "hit@1",
+         "hit@3",
+         "hit@5",
+         "hit@10",
+         "hit@20",
+         "hit@50",
+     ]
+
+     mor_path = MOR4Path(dataset_name, text_retriever_name, scorer_name, skb)
+     reasoner = Planner(dataset_name)
+     outputs = []
+     topk = 100
+     split_idx = qa.get_idx_split(test_ratio=1.0)
+     mod = args.mod
+     all_indices = split_idx[mod].tolist()
+     eval_csv = pd.DataFrame(columns=["idx", "query_id", "pred_rank"] + eval_metrics)
+
+     count = 0
+
+     # ***** Planning *****
+     # If the plan cache exists, load it; otherwise generate and cache the plans.
+     plan_cache_path = f"./cache/{dataset_name}/path/{mod}_20250222.pkl"
+     if os.path.exists(plan_cache_path):
+         with open(plan_cache_path, 'rb') as f:
+             plan_output_list = pkl.load(f)
+     else:
+         plan_output_list = []
+         for idx, i in enumerate(tqdm(all_indices)):
+             plan_output = {}
+             query, q_id, ans_ids, _ = qa[i]
+             rg = reasoner(query)  # reasoning graph planned for this query
+
+             plan_output['query'] = query
+             plan_output['q_id'] = q_id
+             plan_output['ans_ids'] = ans_ids
+             plan_output['rg'] = rg
+             plan_output_list.append(plan_output)
+         # save plan_output_list
+         os.makedirs(os.path.dirname(plan_cache_path), exist_ok=True)
+         with open(plan_cache_path, 'wb') as f:
+             pkl.dump(plan_output_list, f)
+
+     # ***** Reasoning *****
+     for idx, i in enumerate(tqdm(all_indices)):
+         query = plan_output_list[idx]['query']
+         q_id = plan_output_list[idx]['q_id']
+         ans_ids = plan_output_list[idx]['ans_ids']
+         rg = plan_output_list[idx]['rg']
+
+         output = mor_path(query, q_id, ans_ids, rg, args)
+
+         ans_ids = torch.LongTensor(ans_ids)
+
+         pred_dict = output['pred_dict']
+         result = mor_path.evaluate(pred_dict, ans_ids, metrics=eval_metrics)
+
+         result["idx"], result["query_id"] = i, q_id
+         result["pred_rank"] = torch.LongTensor(list(pred_dict.keys()))[
+             torch.argsort(torch.tensor(list(pred_dict.values())), descending=True)[
+                 :topk
+             ]
+         ].tolist()
+
+         eval_csv = pd.concat([eval_csv, pd.DataFrame([result])], ignore_index=True)
+
+         output['q_id'] = q_id
+         outputs.append(output)
+
+         count += 1
+
+     # for metric in eval_metrics:
+     #     print(
+     #         f"{metric}: {np.mean(eval_csv[eval_csv['idx'].isin(all_indices)][metric])}"
+     #     )
+
+     print(f"MOR count: {mor_path.mor_count}")
+
+     # prepare trajectories for the reranker and save them
+     bm25 = mor_path.text_retriever
+     test_data = prepare_trajectories(dataset_name, bm25, skb, outputs)
+     save_path = f"{dataset_name}_{mod}.pkl"
+     with open(save_path, 'wb') as f:
+         pkl.dump(test_data, f)
eval_mor.sh ADDED
@@ -0,0 +1,21 @@
+ #!/bin/bash
+
+ datasets=("mag" "amazon" "prime")
+ # Define scorer_name mapping using an associative array
+ declare -A dataset_scorer_map=(
+     [mag]="ada"
+     [amazon]="ada"
+     [prime]="contriever"
+ )
+
+ for dataset in "${datasets[@]}"; do
+     # Get the corresponding scorer_name for the dataset
+     scorer_name="${dataset_scorer_map[$dataset]}"
+     echo "Processing dataset: $dataset with scorer: $scorer_name"
+     python eval.py --dataset_name "$dataset" --scorer_name "$scorer_name" --mod "test"
+
+     cd Reranking
+     python rerank.py --dataset_name "$dataset"
+     cd ..
+ done
get_emb.py ADDED
@@ -0,0 +1,142 @@
+ import os
+ import os.path as osp
+ import random
+ import sys
+ import argparse
+
+ import pandas as pd
+ import torch
+ from tqdm import tqdm
+
+ from stark_qa.tools.api_lib.openai_emb import get_contriever, get_contriever_embeddings
+
+ sys.path.append('.')
+ from stark_qa import load_skb, load_qa
+ from stark_qa.tools.api import get_api_embeddings
+ from stark_qa.tools.local_encoder import get_llm2vec_embeddings, get_gritlm_embeddings
+ from models.model import get_embeddings
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser()
+
+     # Dataset and embedding model selection
+     parser.add_argument('--dataset', default='prime', choices=['amazon', 'prime', 'mag'])
+     parser.add_argument('--emb_model', default='contriever',
+                         choices=[
+                             'contriever',
+                             'text-embedding-ada-002',
+                             'text-embedding-3-small',
+                             'text-embedding-3-large',
+                             'voyage-large-2-instruct',
+                             'GritLM/GritLM-7B',
+                             'McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp',
+                             'all-mpnet-base-v2'  # for sentence transformers
+                         ]
+                         )
+
+     # Mode settings
+     parser.add_argument('--mode', default='query', choices=['doc', 'query'])
+
+     # Path settings
+     parser.add_argument("--data_dir", default="data/", type=str)
+     parser.add_argument("--emb_dir", default="emb/", type=str)
+
+     # Text settings
+     parser.add_argument('--add_rel', action='store_true', default=False, help='add relation to the text')
+     parser.add_argument('--compact', action='store_true', default=False, help='make the text compact when input to the model')
+
+     # Evaluation settings
+     parser.add_argument("--human_generated_eval", action="store_true", help="if mode is `query`, generate query embeddings on the human-generated evaluation split")
+
+     # Batch and node settings
+     parser.add_argument("--batch_size", default=1024, type=int)
+
+     # encode kwargs (collected via the custom metavar "ENCODE")
+     parser.add_argument("--n_max_nodes", default=None, type=int, metavar="ENCODE")
+     parser.add_argument("--device", default=None, type=str, metavar="ENCODE")
+     parser.add_argument("--peft_model_name", default=None, type=str, help="llm2vec peft model", metavar="ENCODE")
+     parser.add_argument("--instruction", type=str, help="gritlm/llm2vec instruction", metavar="ENCODE")
+
+     args = parser.parse_args()
+
+     # Create encode_kwargs from all set arguments whose metavar is "ENCODE"
+     encode_kwargs = {k: v for k, v in vars(args).items() if v is not None and parser._option_string_actions[f'--{k}'].metavar == "ENCODE"}
+
+     return args, encode_kwargs
+
+
+ if __name__ == '__main__':
+     args, encode_kwargs = parse_args()
+     args.human_generated_eval = False  # NOTE: forced off here, overriding the --human_generated_eval flag
+     mode_suffix = '_human_generated_eval' if args.human_generated_eval and args.mode == 'query' else ''
+     mode_suffix += '_no_rel' if not args.add_rel else ''
+     mode_suffix += '_no_compact' if not args.compact else ''
+     emb_dir = osp.join(args.emb_dir, args.dataset, args.emb_model, f'{args.mode}{mode_suffix}')
+     csv_cache = osp.join(args.data_dir, args.dataset, f'{args.mode}{mode_suffix}.csv')
+
+     print(f'Embedding directory: {emb_dir}')
+     os.makedirs(emb_dir, exist_ok=True)
+     os.makedirs(os.path.dirname(csv_cache), exist_ok=True)
+
+     if args.mode == 'doc':
+         skb = load_skb(args.dataset)
+         lst = skb.candidate_ids
+         emb_path = osp.join(emb_dir, 'candidate_emb_dict.pt')
+     if args.mode == 'query':
+         qa_dataset = load_qa(args.dataset, human_generated_eval=args.human_generated_eval)
+         lst = [qa_dataset[i][1] for i in range(len(qa_dataset))]
+         emb_path = osp.join(emb_dir, 'query_emb_dict.pt')
+     random.shuffle(lst)
+
+     # Load existing embeddings if they exist
+     if osp.exists(emb_path):
+         emb_dict = torch.load(emb_path)
+         exist_emb_indices = list(emb_dict.keys())
+         print(f'Loaded existing embeddings from {emb_path}. Size: {len(emb_dict)}')
+     else:
+         emb_dict = {}
+         exist_emb_indices = []
+
+     # Load the existing document cache if it exists (only for doc mode)
+     if args.mode == 'doc' and osp.exists(csv_cache):
+         df = pd.read_csv(csv_cache)
+         cache_dict = dict(zip(df['index'], df['text']))
+
+         # Ensure that the indices in the cache match the expected indices
+         assert set(cache_dict.keys()) == set(lst), 'Indices in cache do not match the candidate indices.'
+
+         indices = list(set(lst) - set(exist_emb_indices))
+         texts = [cache_dict[idx] for idx in tqdm(indices, desc="Filtering docs for new embeddings")]
+     else:
+         indices = lst
+         texts = [qa_dataset.get_query_by_qid(idx) if args.mode == 'query'
+                  else skb.get_doc_info(idx, add_rel=args.add_rel, compact=args.compact) for idx in tqdm(indices, desc="Gathering docs")]
+         if args.mode == 'doc':
+             df = pd.DataFrame({'index': indices, 'text': texts})
+             df.to_csv(csv_cache, index=False)
+
+     print(f'Generating embeddings for {len(texts)} texts...')
+     if args.emb_model == 'contriever':
+         encoder, tokenizer = get_contriever(dataset_name=args.dataset)
+         for i in tqdm(range(0, len(texts), args.batch_size), desc="Generating embeddings"):
+             batch_texts = texts[i:i+args.batch_size]
+             batch_embs = get_contriever_embeddings(batch_texts, encoder=encoder, tokenizer=tokenizer, device='cuda')
+             batch_embs = batch_embs.view(len(batch_texts), -1).cpu()
+
+             batch_indices = indices[i:i+args.batch_size]
+             for idx, emb in zip(batch_indices, batch_embs):
+                 emb_dict[idx] = emb.view(1, -1)
+     else:
+         for i in tqdm(range(0, len(texts), args.batch_size), desc="Generating embeddings"):
+             batch_texts = texts[i:i+args.batch_size]
+             batch_embs = get_embeddings(batch_texts, args.emb_model, **encode_kwargs)
+             batch_embs = batch_embs.view(len(batch_texts), -1).cpu()
+
+             batch_indices = indices[i:i+args.batch_size]
+             for idx, emb in zip(batch_indices, batch_embs):
+                 emb_dict[idx] = emb.view(1, -1)
+
+     torch.save(emb_dict, emb_path)
+     print(f'Saved {len(emb_dict)} embeddings to {emb_path}!')
mor_env.yml ADDED
@@ -0,0 +1,327 @@
+ name: your_env_name
+ channels:
+   - defaults
+ dependencies:
+   - _libgcc_mutex=0.1=main
+   - _openmp_mutex=5.1=1_gnu
+   - bzip2=1.0.8=h5eee18b_6
+   - ca-certificates=2024.11.26=h06a4308_0
+   - debugpy=1.6.7=py311h6a678d5_0
+   - decorator=5.1.1=pyhd3eb1b0_0
+   - ipykernel=6.29.5=py311h06a4308_0
+   - jedi=0.19.2=py311h06a4308_0
+   - jupyter_client=8.6.0=py311h06a4308_0
+   - jupyter_core=5.7.2=py311h06a4308_0
+   - ld_impl_linux-64=2.40=h12ee557_0
+   - libffi=3.4.4=h6a678d5_1
+   - libgcc-ng=11.2.0=h1234567_1
+   - libgomp=11.2.0=h1234567_1
+   - libsodium=1.0.18=h7b6447c_0
+   - libstdcxx-ng=11.2.0=h1234567_1
+   - libuuid=1.41.5=h5eee18b_0
+   - ncurses=6.4=h6a678d5_0
+   - nest-asyncio=1.6.0=py311h06a4308_0
+   - openssl=3.0.15=h5eee18b_0
+   - parso=0.8.4=py311h06a4308_0
+   - pip=24.2=py311h06a4308_0
+   - prompt_toolkit=3.0.43=hd3eb1b0_0
+   - ptyprocess=0.7.0=pyhd3eb1b0_2
+   - pure_eval=0.2.2=pyhd3eb1b0_0
+   - python=3.11.11=he870216_0
+   - python-dateutil=2.9.0post0=py311h06a4308_2
+   - pyzmq=25.1.2=py311h6a678d5_0
+   - readline=8.2=h5eee18b_0
+   - setuptools=75.1.0=py311h06a4308_0
+   - six=1.16.0=pyhd3eb1b0_1
+   - sqlite=3.45.3=h5eee18b_0
+   - stack_data=0.2.0=pyhd3eb1b0_0
+   - tk=8.6.14=h39e8969_0
+   - tornado=6.4.2=py311h5eee18b_0
+   - traitlets=5.14.3=py311h06a4308_0
+   - typing_extensions=4.11.0=py311h06a4308_0
+   - wheel=0.44.0=py311h06a4308_0
+   - xz=5.4.6=h5eee18b_1
+   - zeromq=4.3.5=h6a678d5_0
+   - zlib=1.2.13=h5eee18b_1
+   - pip:
+       - accelerate==1.1.1
+       - aiobotocore==2.15.2
+       - aiohappyeyeballs==2.4.4
+       - aiohttp==3.11.9
+       - aioitertools==0.12.0
+       - aiolimiter==1.2.0
+       - aiosignal==1.3.1
+       - anndata==0.11.1
+       - annotated-types==0.7.0
+       - anthropic==0.40.0
+       - anyascii==0.3.2
+       - anyio==4.6.2.post1
+       - argon2-cffi==23.1.0
+       - argon2-cffi-bindings==21.2.0
+       - array-api-compat==1.9.1
+       - arrow==1.3.0
+       - asttokens==3.0.0
+       - async-lru==2.0.4
+       - attrs==24.2.0
+       - babel==2.16.0
+       - backports-tarfile==1.2.0
+       - beautifulsoup4==4.12.3
+       - biopython==1.84
+       - biothings-client==0.3.1
+       - bitarray==3.0.0
+       - bitsandbytes==0.44.1
+       - bleach==6.2.0
+       - blinker==1.9.0
+       - bm25s==0.2.5
+       - botocore==1.35.36
+       - bs4==0.0.2
+       - cattrs==24.1.2
+       - cellxgene-census==1.15.0
+       - certifi==2024.8.30
+       - cffi==1.17.1
+       - charset-normalizer==3.4.0
+       - chembl-webresource-client==0.10.9
+       - click==8.1.7
+       - colbert==0.40
+       - colbert-ai==0.2.21
+       - comm==0.2.2
+       - contourpy==1.3.1
+       - contractions==0.1.73
+       - cryptography==44.0.0
+       - cut-cross-entropy==24.11.4
+       - cycler==0.12.1
+       - dataclasses==0.6
+       - datasets==3.1.0
+       - defusedxml==0.7.1
+       - dill==0.3.8
+       - distro==1.9.0
+       - docker-pycreds==0.4.0
+       - docstring-parser==0.16
+       - docutils==0.21.2
+       - easydict==1.13
+       - et-xmlfile==2.0.0
+       - evaluate==0.4.2
+       - executing==2.1.0
+       - fastjsonschema==2.21.1
+       - filelock==3.13.1
+       - flask==3.1.0
+       - fonttools==4.55.1
+       - fqdn==1.5.1
+       - frozenlist==1.5.0
+       - fsspec==2024.2.0
+       - fuzzywuzzy==0.18.0
+       - gdown==5.2.0
+       - gget==0.29.0
+       - git-python==1.0.3
+       - gitdb==4.0.11
+       - gitpython==3.1.43
+       - greenlet==3.1.1
+       - h11==0.14.0
+       - h5py==3.12.1
+       - hf-transfer==0.1.8
+       - httpcore==1.0.7
+       - httpx==0.28.0
+       - huggingface-hub==0.26.3
+       - icalendar==6.1.0
+       - idna==3.10
+       - importlib-metadata==8.5.0
+       - ipython==8.30.0
+       - ipywidgets==8.1.5
+       - isoduration==20.11.0
+       - itsdangerous==2.2.0
+       - jaraco-classes==3.4.0
+       - jaraco-context==6.0.1
+       - jaraco-functools==4.1.0
+       - jeepney==0.8.0
+       - jinja2==3.1.3
+       - jiter==0.8.0
+       - jmespath==1.0.1
+       - joblib==1.4.2
+       - json5==0.10.0
+       - jsonpatch==1.33
+       - jsonpointer==3.0.0
+       - jsonschema==4.23.0
+       - jsonschema-specifications==2024.10.1
+       - jupyter==1.1.1
+       - jupyter-console==6.6.3
+       - jupyter-events==0.11.0
+       - jupyter-lsp==2.2.5
+       - jupyter-server==2.15.0
+       - jupyter-server-terminals==0.5.3
+       - jupyterlab==4.3.4
+       - jupyterlab-pygments==0.3.0
+       - jupyterlab-server==2.27.3
+       - jupyterlab-widgets==3.0.13
+       - keyring==25.5.0
+       - kiwisolver==1.4.7
+       - langchain==0.3.9
+       - langchain-core==0.3.21
+       - langchain-text-splitters==0.3.2
+       - langdetect==1.0.9
+       - langsmith==0.1.147
+       - legacy-api-wrap==1.4.1
+       - levenshtein==0.26.1
+       - lightning-utilities==0.11.9
+       - littleutils==0.2.4
+       - llvmlite==0.43.0
+       - lxml==5.3.0
+       - markdown-it-py==3.0.0
+       - markupsafe==2.1.5
+       - matplotlib==3.9.3
+       - matplotlib-inline==0.1.7
+       - matplotlib-venn==1.1.1
+       - mdurl==0.1.2
+       - mistune==3.1.1
+       - moleculeace==3.0.0
+       - more-itertools==10.5.0
+       - mpmath==1.3.0
+       - multidict==6.1.0
+       - multiprocess==0.70.16
+       - mygene==3.2.2
+       - mysql-connector-python==9.1.0
+       - natsort==8.4.0
+       - nbclient==0.10.2
+       - nbconvert==7.16.6
+       - nbformat==5.10.4
+       - networkx==3.2.1
+       - nh3==0.2.19
+       - ninja==1.11.1.2
+       - nltk==3.9.1
+       - notebook==7.3.2
+       - notebook-shim==0.2.4
+       - numba==0.60.0
+       - numpy==1.26.4
+       - nvidia-cublas-cu12==12.4.5.8
+       - nvidia-cuda-cupti-cu12==12.4.127
+       - nvidia-cuda-nvrtc-cu12==12.4.127
+       - nvidia-cuda-runtime-cu12==12.4.127
+       - nvidia-cudnn-cu12==9.1.0.70
+       - nvidia-cufft-cu12==11.2.1.3
+       - nvidia-curand-cu12==10.3.5.147
+       - nvidia-cusolver-cu12==11.6.1.9
+       - nvidia-cusparse-cu12==12.3.1.170
+       - nvidia-nccl-cu12==2.21.5
+       - nvidia-nvjitlink-cu12==12.4.127
+       - nvidia-nvtx-cu12==12.4.127
+       - ogb==1.3.6
+       - openai==1.56.1
+       - openpyxl==3.1.5
+       - orjson==3.10.12
+       - outdated==0.2.2
+       - overrides==7.7.0
+       - packaging==24.2
+       - pandas==2.2.3
+       - pandocfilters==1.5.1
+       - patsy==1.0.1
+       - peft==0.13.2
+       - pexpect==4.9.0
+       - pillow==10.2.0
+       - pkginfo==1.12.0
+       - platformdirs==4.3.6
+       - prometheus-client==0.21.1
+       - prompt-toolkit==3.0.48
+       - propcache==0.2.1
+       - protobuf==3.20.3
+       - psutil==6.1.0
+       - pure-eval==0.2.3
+       - pyahocorasick==2.1.0
+       - pyaml==24.9.0
+       - pyarrow==18.1.0
+       - pyarrow-hotfix==0.6
+       - pycparser==2.22
+       - pydantic==2.10.3
+       - pydantic-core==2.27.1
+       - pygments==2.18.0
+       - pynndescent==0.5.13
+       - pyparsing==3.2.0
+       - pysocks==1.7.1
+       - pytdc==1.1.1
+       - python-dotenv==1.0.1
+       - python-json-logger==3.2.1
+       - python-levenshtein==0.26.1
+       - pytz==2024.2
+       - pyyaml==6.0.2
+       - rapidfuzz==3.10.1
+       - rdkit==2023.9.6
+       - rdkit-pypi==2022.9.5
+       - readme-renderer==44.0
+       - referencing==0.36.2
+       - regex==2024.11.6
+       - requests==2.32.3
+       - requests-cache==1.2.1
+       - requests-toolbelt==1.0.0
+       - rfc3339-validator==0.1.4
+       - rfc3986==2.0.0
+       - rfc3986-validator==0.1.1
+       - rich==13.9.4
+       - rpds-py==0.22.3
+       - s3fs==2024.2.0
+       - safetensors==0.4.5
+       - scanpy==1.10.4
+       - scikit-learn==1.2.2
+       - scikit-optimize==0.10.2
+       - scipy==1.14.1
+       - seaborn==0.13.2
+       - secretstorage==3.3.3
+       - send2trash==1.8.3
+       - sentence-transformers==3.3.1
+       - sentencepiece==0.2.0
+       - sentry-sdk==2.19.0
+       - session-info==1.0.0
+       - setproctitle==1.3.4
+       - shtab==1.7.1
+       - smmap==5.0.1
+       - sniffio==1.3.1
+       - somacore==1.0.11
+       - soupsieve==2.6
+       - sqlalchemy==2.0.36
+       - stack-data==0.6.3
+       - statsmodels==0.14.4
+       - stdlib-list==0.11.0
+       - sympy==1.13.1
+       - tenacity==9.0.0
+       - terminado==0.18.1
+       - textsearch==0.0.24
+       - threadpoolctl==3.5.0
+       - tiledb==0.29.1
+       - tiledbsoma==1.11.4
+       - tinycss2==1.4.0
+       - tokenizers==0.20.3
+       - torch==2.5.1+cu124
+       - torch-geometric==2.6.1
+       - torch-scatter==2.1.2+pt25cu124
+       - torchaudio==2.5.1+cu124
+       - torchmetrics==1.6.0
+       - torchvision==0.20.1+cu124
+       - tqdm==4.67.1
+       - transformers==4.46.3
+       - triton==3.1.0
+       - trl==0.12.1
+       - twine==6.0.1
+       - typeguard==4.4.1
+       - types-python-dateutil==2.9.0.20241206
+       - typing-extensions==4.12.2
+       - tyro==0.9.2
+       - tzdata==2024.2
+       - ujson==5.10.0
+       - umap-learn==0.5.7
+       - unsloth==2025.1.8
+       - unsloth-zoo==2025.1.5
+       - uri-template==1.3.0
+       - url-normalize==1.4.3
+       - urllib3==2.2.3
+       - voyageai==0.3.2
+       - wandb==0.18.7
+       - wcwidth==0.2.13
+       - webcolors==24.11.1
+       - webencodings==0.5.1
+       - websocket-client==1.8.0
+       - werkzeug==3.1.3
+       - widgetsnbextension==4.0.13
+       - wrapt==1.17.0
+       - xformers==0.0.28.post3
+       - xxhash==3.5.0
+       - yapf==0.43.0
+       - yarl==1.18.3
+       - zipp==3.21.0
+ prefix: /home/yongjia/.conda/envs/g_traversal
prepare_rerank.py ADDED
@@ -0,0 +1,245 @@
+ from Reasoning.text_retrievers.contriever import Contriever
+ from Reasoning.text_retrievers.ada import Ada
+ from stark_qa import load_qa, load_skb
+
+ import pickle as pkl
+ from tqdm import tqdm
+ from transformers import BertTokenizer, BertModel
+
+ model_name = "bert-base-uncased"
+
+ tokenizer = BertTokenizer.from_pretrained(model_name)
+ encoder = BertModel.from_pretrained(model_name)
+
+
+ def get_bm25_scores(dataset_name, bm25, outputs):
+     new_outputs = []
+     # use tqdm to visualize the progress
+     for i in tqdm(range(len(outputs))):
+         query, q_id, ans_ids = outputs[i]['query'], outputs[i]['q_id'], outputs[i]['ans_ids']
+         paths = outputs[i]['paths']
+         rg = outputs[i]['rg']
+
+         if dataset_name == 'prime':
+             new_path_dict = paths
+         else:
+             # make a new path dict with the -1 padding removed from each path
+             new_path_dict = {}
+             for key in paths.keys():
+                 new_path = [x for x in paths[key] if x != -1]
+                 new_path_dict[key] = new_path
+
+         # collect all nodes on the paths except the first element
+         candidates_ids = []
+         for key in new_path_dict.keys():
+             candidates_ids.extend(new_path_dict[key][1:])
+         candidates_ids.extend(ans_ids)
+         candidates_ids = list(set(candidates_ids))
+
+         # get the bm25 scores
+         bm_score_dict = bm25.score(query, q_id, candidate_ids=candidates_ids)
+         outputs[i]['bm_score_dict'] = bm_score_dict
+
+         # replace -1 placeholders in the bm_vector_dict with the bm25 scores
+         bm_vector_dict = outputs[i]['bm_vector_dict']
+         for key in bm_vector_dict.keys():
+             if -1 in bm_vector_dict[key]:
+                 path = new_path_dict[key]
+                 assert len(path) == len(bm_vector_dict[key])
+
+                 bm_vector_dict[key] = [bm_score_dict[path[j]] if x == -1 else x for j, x in enumerate(bm_vector_dict[key])]
+
+         outputs[i]['bm_vector_dict'] = bm_vector_dict
+
+         # fix the length of paths in prime: pad with -1 or truncate to max_len
+         if dataset_name == 'prime':
+             max_len = 3
+             new_paths = {}
+             for key in paths:
+                 new_path = paths[key]
+                 if len(paths[key]) < max_len:
+                     new_path = [-1] * (max_len - len(paths[key])) + paths[key]
+                 elif len(paths[key]) > max_len:
+                     new_path = paths[key][-max_len:]
+                 new_paths[key] = new_path
+
+             # assign the new paths
+             outputs[i]['paths'] = new_paths
+
+         new_outputs.append(outputs[i])
+
+     return new_outputs
+
+
+ def prepare_score_vector_dict(raw_data):
+     # make the score_vector_dict: [bm_score, bm_score, bm_score, ada_score/contriever_score]
+     for i in range(len(raw_data)):
+         # get the pred_dict
+         pred_dict = raw_data[i]['pred_dict']
+         # get the bm_vector_dict
+         bm_vector_dict = raw_data[i]['bm_vector_dict']
+         # initialize the score_vector_dict
+         raw_data[i]['score_vector_dict'] = {}
+         # append the value of pred_dict to the end of each bm_vector
+         for key in pred_dict:
+             # get the bm25 score vector for this candidate
+             bm_vector = bm_vector_dict[key]
+             # get the ranking score
+             rk_score = pred_dict[key]
+             # build the score vector
+             score_vector = bm_vector + [rk_score]
+             # check the length of the score_vector: if less than 4, pad with 0 at the beginning
+             if len(score_vector) < 4:
+                 score_vector = [0] * (4 - len(score_vector)) + score_vector
+             elif len(score_vector) > 4:
+                 score_vector = score_vector[-4:]
+             # store it in the score_vector_dict
+             raw_data[i]['score_vector_dict'][key] = score_vector
+
+     return raw_data
+
+
+ def prepare_text_emb_symb_enc(raw_data, skb):
+     # add the text_emb to the raw_data
+     text2emb_list = []
+     text2emb_dict = {}
+
+     # symbolic encoding of a path, keyed by its number of real (non -1) nodes
+     symbolic_encode_dict = {
+         3: [0, 1, 1],
+         2: [2, 0, 1],
+         1: [2, 2, 0],
+     }
+
+     for i in range(len(raw_data)):
+         # get the paths
+         paths = raw_data[i]['paths']
+         preds = raw_data[i]['pred_dict']
+         assert len(paths) == len(preds)
+
+         # initialize the text_emb_dict
+         raw_data[i]['text_emb_dict'] = {}
+
+         # initialize the symb_enc_dict
+         raw_data[i]['symb_enc_dict'] = {}
+
+         for key in paths:
+             # get the path
+             path = paths[key]
+             # make a unique text path (node types) and record it in the dict
+             text_path_li = [skb.get_node_type_by_id(node_id) if node_id != -1 else "padding" for node_id in path]
+             text_path_str = " ".join(text_path_li)
+             if text_path_str not in text2emb_list:
+                 text2emb_list.append(text_path_str)
+                 text2emb_dict[text_path_str] = -1
+
+             # assign the text path to the raw_data
+             raw_data[i]['text_emb_dict'][key] = text_path_str
+
+             # ***** make the symb_enc_dict *****
+             # number of non -1 entries in the path
+             num_non_1 = len([p for p in path if p != -1])
+             # get the symbolic encoding
+             symb_enc = symbolic_encode_dict[num_non_1]
+             # store it in the symb_enc_dict
+             raw_data[i]['symb_enc_dict'][key] = symb_enc
+
+     # ***** get the text2emb_dict embeddings *****
+     for key in text2emb_dict.keys():
+         # tokenize the node-type path with the BERT tokenizer
+         text_enc = tokenizer(key, return_tensors='pt')['input_ids']
+         outputs = encoder(text_enc)
+         last_hidden_states = outputs.last_hidden_state.mean(dim=1)
+         text2emb_dict[key] = last_hidden_states.detach()
+
+     new_data = {'data': raw_data, 'text2emb_dict': text2emb_dict}
+
+     return new_data
+
+
+ def prepare_trajectories(dataset_name, bm25, skb, outputs):
+     # get the bm25 scores
+     new_outputs = get_bm25_scores(dataset_name, bm25, outputs)  # returns a list
+     # prepare the score_vector_dict
+     new_outputs = prepare_score_vector_dict(new_outputs)  # returns a list
+     # prepare the text_emb and symb_enc_dict
+     new_data = prepare_text_emb_symb_enc(new_outputs, skb)  # returns a dict
+
+     return new_data
+
+
+ def get_contriever_scores(dataset_name, mod, skb, path):
+     with open(path, 'rb') as f:
+         data = pkl.load(f)
+
+     raw_data = data['data']
+
+     qa = load_qa(dataset_name, human_generated_eval=False)
+
+     contriever = Contriever(skb, dataset_name, device='cuda')
+
+     split_idx = qa.get_idx_split(test_ratio=1.0)
+
+     all_indices = split_idx[mod].tolist()
+     # use tqdm to visualize the progress
+     for idx, i in enumerate(tqdm(all_indices)):
+         query, q_id, ans_ids, _ = qa[i]
+         assert query == raw_data[idx]['query']
+         pred_ids = list(raw_data[idx]['pred_dict'].keys())
+         candidates_ids = list(set(pred_ids))
+         candidates_ids.extend(ans_ids)
+
+         # get the contriever scores
+         contriever_score_dict = contriever.score(query, q_id, candidate_ids=candidates_ids)
+
+         raw_data[idx]['contriever_score_dict'] = contriever_score_dict
+
+     data['data'] = raw_data
+
+     with open(path, 'wb') as f:
+         pkl.dump(data, f)
+
+
+ def get_ada_scores(dataset_name, mod, skb, path):
+     with open(path, 'rb') as f:
+         data = pkl.load(f)
+
+     raw_data = data['data']
+
+     qa = load_qa(dataset_name, human_generated_eval=False)
+
+     ada = Ada(skb, dataset_name, device='cuda')
+
+     split_idx = qa.get_idx_split(test_ratio=1.0)
+
+     all_indices = split_idx[mod].tolist()
+     # use tqdm to visualize the progress
+     for idx, i in enumerate(tqdm(all_indices)):
+         query, q_id, ans_ids, _ = qa[i]
+         assert query == raw_data[idx]['query']
+         pred_ids = list(raw_data[idx]['pred_dict'].keys())
+         candidates_ids = list(set(pred_ids))
+         candidates_ids.extend(ans_ids)
+
+         # get the ada scores
+         ada_score_dict = ada.score(query, q_id, candidate_ids=candidates_ids)
+
+         raw_data[idx]['ada_score_dict'] = ada_score_dict
+
+     data['data'] = raw_data
+
+     with open(path, 'wb') as f:
+         pkl.dump(data, f)
+
+
+ if __name__ == '__main__':
+     print("Test prepare_rerank")
requirements.txt ADDED
@@ -0,0 +1,31 @@
+ anthropic==0.45.2
+ beautifulsoup4==4.13.3
+ bm25s==0.2.5
+ Colbert==0.40
+ colbert_ai==0.2.21
+ contractions==0.1.73
+ datasets==3.1.0
+ gdown==5.2.0
+ gritlm==1.0.2
+ huggingface_hub==0.26.3
+ langchain==0.3.18
+ langdetect==1.0.9
+ llm2vec==0.2.3
+ nltk==3.9.1
+ numpy==2.2.3
+ ogb==1.3.6
+ openai==1.63.0
+ pandas==2.2.3
+ PyTDC==1.1.1
+ scikit_learn==1.2.2
+ sentence_transformers==3.3.1
+ torch==2.5.1+cu124
+ torch_geometric==2.6.1
+ torch_scatter==2.1.2+pt25cu124
+ torchmetrics==1.6.0
+ tqdm==4.67.1
+ transformers==4.46.3
+ trl==0.12.1
+ unsloth==2025.1.8
+ voyageai==0.3.2
+ wandb==0.19.6
run_reasoning.sh ADDED
@@ -0,0 +1,23 @@
+ #!/bin/bash
+
+ # Define datasets and mods
+ datasets=("prime")
+ mods=("test" "val" "train")
+
+ # Define scorer_name mapping using an associative array
+ declare -A dataset_scorer_map=(
+     [mag]="ada"
+     [amazon]="ada"
+     [prime]="contriever"
+ )
+
+ # Loop through datasets and mods
+ for dataset in "${datasets[@]}"; do
+     # Get the corresponding scorer_name for the dataset
+     scorer_name="${dataset_scorer_map[$dataset]}"
+
+     for mod in "${mods[@]}"; do
+         echo "Processing dataset: $dataset with mod: $mod and scorer: $scorer_name"
+         python eval.py --dataset_name "$dataset" --scorer_name "$scorer_name" --mod "$mod"
+     done
+ done
train_planner.sh ADDED
@@ -0,0 +1,7 @@
+ #!/bin/bash
+
+ # Navigate to the Planning directory
+ cd Planning
+
+ # Run the training script
+ python train_eval.py
train_reranker.sh ADDED
@@ -0,0 +1,13 @@
+ #!/bin/bash
+
+ # Navigate to the Reranking directory
+ cd Reranking
+
+ # Run the training scripts
+ # amazon
+ python train_eval_path_amazon.py
+ # mag
+ python train_eval_path_mag.py
+ # prime
+ python train_eval_path_prime.py