GagaLey committed on
Commit
7bf4b88
1 Parent(s): b75ad03
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. Planning/__pycache__/model.cpython-311.pyc +0 -0
  2. Planning/__pycache__/prompts.cpython-311.pyc +0 -0
  3. Planning/__pycache__/utils.cpython-311.pyc +0 -0
  4. Planning/data/finetune/amazon/1000.json +0 -0
  5. Planning/data/finetune/amazon/llama_ft.jsonl +0 -0
  6. Planning/data/finetune/combine_triplets.py +93 -0
  7. Planning/data/finetune/mag/1000.json +0 -0
  8. Planning/data/finetune/mag/llama_ft.jsonl +0 -0
  9. Planning/data/finetune/prime/1000.json +55 -0
  10. Planning/data/finetune/prime/llama_ft.jsonl +0 -0
  11. Planning/data/get_train_data/__pycache__/prompts.cpython-311.pyc +0 -0
  12. Planning/data/get_train_data/get_llm_data.py +237 -0
  13. Planning/data/get_train_data/post_process_data.py +58 -0
  14. Planning/data/get_train_data/prompts.py +274 -0
  15. Planning/data/train_eval.py +223 -0
  16. Planning/model.py +89 -0
  17. Planning/utils.py +4 -0
  18. Reasoning/__pycache__/mor4node.cpython-311.pyc +0 -0
  19. Reasoning/__pycache__/mor4node_copy.cpython-311.pyc +0 -0
  20. Reasoning/__pycache__/mor4path.cpython-311.pyc +0 -0
  21. Reasoning/__pycache__/ptp_mor4node.cpython-311.pyc +0 -0
  22. Reasoning/__pycache__/utils.cpython-311.pyc +0 -0
  23. Reasoning/mor4path.py +435 -0
  24. Reasoning/structural_retriever/__pycache__/stru4path.cpython-311.pyc +0 -0
  25. Reasoning/structural_retriever/stru4path.py +305 -0
  26. Reasoning/text_retrievers/__init__.py +4 -0
  27. Reasoning/text_retrievers/__pycache__/__init__.cpython-311.pyc +0 -0
  28. Reasoning/text_retrievers/__pycache__/ada.cpython-311.pyc +0 -0
  29. Reasoning/text_retrievers/__pycache__/bm25.cpython-311.pyc +0 -0
  30. Reasoning/text_retrievers/__pycache__/contriever.cpython-311.pyc +0 -0
  31. Reasoning/text_retrievers/__pycache__/stark_model.cpython-311.pyc +0 -0
  32. Reasoning/text_retrievers/ada.py +66 -0
  33. Reasoning/text_retrievers/bm25.py +108 -0
  34. Reasoning/text_retrievers/contriever.py +78 -0
  35. Reasoning/text_retrievers/stark_model.py +151 -0
  36. Reasoning/utils.py +116 -0
  37. Reranking/__pycache__/rerank.cpython-311.pyc +0 -0
  38. Reranking/__pycache__/utils.cpython-311.pyc +0 -0
  39. Reranking/data/checkpoints/amazon/best.pth +3 -0
  40. Reranking/data/checkpoints/mag/best.pth +3 -0
  41. Reranking/data/checkpoints/prime/best.pth +3 -0
  42. Reranking/rerank.py +375 -0
  43. Reranking/rerankers/__pycache__/node.cpython-311.pyc +0 -0
  44. Reranking/rerankers/__pycache__/path.cpython-311.pyc +0 -0
  45. Reranking/rerankers/node.py +87 -0
  46. Reranking/rerankers/path.py +107 -0
  47. Reranking/train_eval_path_amazon.py +694 -0
  48. Reranking/train_eval_path_mag.py +694 -0
  49. Reranking/train_eval_path_prime.py +691 -0
  50. Reranking/utils.py +90 -0
Planning/__pycache__/model.cpython-311.pyc ADDED
Binary file (3.89 kB).
 
Planning/__pycache__/prompts.cpython-311.pyc ADDED
Binary file (13.1 kB).
 
Planning/__pycache__/utils.cpython-311.pyc ADDED
Binary file (493 Bytes).
 
Planning/data/finetune/amazon/1000.json ADDED
The diff for this file is too large to render. See raw diff
 
Planning/data/finetune/amazon/llama_ft.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
Planning/data/finetune/combine_triplets.py ADDED
@@ -0,0 +1,93 @@
+ import json
+ import os
+
+
+ path = "./prime/1000.json"
+ with open(path, 'r') as f:
+     data = json.load(f)
+
+ def combine_triplets_by_boundary_no_loops(NTar):
+     combined = NTar[:]
+     changes = True  # Track if merges happen during an iteration
+
+     while changes:
+         changes = False  # Reset changes for this iteration
+         new_combined = []  # To hold the merged triplets
+         used = [False] * len(combined)  # Track which triplets have been processed
+
+         for i in range(len(combined)):
+             if used[i]:
+                 continue  # Skip already processed triplets
+
+             current_triplet = list(combined[i])  # Copy the triplet so merging does not mutate the input
+             used[i] = True
+             merged = False  # Track if the current triplet is merged with another
+
+             for j in range(len(combined)):
+                 if i != j and not used[j]:
+                     other_triplet = combined[j]
+
+                     # Check if the current triplet can be merged with another
+                     if current_triplet[-1] == other_triplet[0]:  # Current's last matches other's first
+                         current_triplet.extend(other_triplet[1:])  # Merge, excluding the duplicate entity
+                         used[j] = True
+                         merged = True
+                         changes = True  # A merge occurred
+                         break
+                     elif current_triplet[0] == other_triplet[-1]:  # Current's first matches other's last
+                         current_triplet = other_triplet[:-1] + current_triplet  # Merge, excluding the duplicate entity
+                         used[j] = True
+                         merged = True
+                         changes = True
+                         break
+
+             # After merging, or if no merge happened, add the triplet to the new list
+             new_combined.append(current_triplet)
+
+         # Update the combined list with the newly merged triplets
+         combined = new_combined
+
+     return combined
+
+
+ def combine_tar_ntar(tar, ntar):
+     for nt in ntar:
+         for i in range(len(tar)):
+             if tar[i][-1] == nt[0]:
+                 tar[i] = tar[i] + nt[1:]
+             elif tar[i][0] == nt[-1]:
+                 tar[i] = nt[:-1] + tar[i]
+     return tar
+
+ def check_order(routes, target):
+     for i in range(len(routes)):
+         if routes[i][-1] != target:
+             if routes[i][0] != target:
+                 raise ValueError(f"Wrong order: {routes[i]}")
+             else:
+                 routes[i] = routes[i][::-1]
+     return routes
+
+ routes_list = []
+ restrictions_list = []
+ for i in range(len(data)):
+     triplets = data[i]['Triplets']
+     target = data[i]['Target']
+     restrictions = data[i]['Restriction']
+     tar = []
+     ntar = []
+     for tp in triplets:
+         if target in tp:
+             tar.append(tp)
+         else:
+             ntar.append(tp)
+     if len(ntar) > 0:
+         ntar = combine_triplets_by_boundary_no_loops(ntar)
+         routes = combine_tar_ntar(tar, ntar)
+     else:
+         routes = tar
+     print(target)
+     routes = check_order(routes, target)
+     routes_list.append(routes)
+     restrictions_list.append(restrictions)
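A minimal sketch of how the boundary merging above behaves, using two toy schema triplets (hypothetical values; assumes the functions defined in this file):

    # Two triplets share the boundary type "gene/protein", so they chain into one route.
    ntar = [["anatomy", "expression absent", "gene/protein"],
            ["gene/protein", "interacts with", "pathway"]]
    print(combine_triplets_by_boundary_no_loops(ntar))
    # -> [['anatomy', 'expression absent', 'gene/protein', 'interacts with', 'pathway']]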
Planning/data/finetune/mag/1000.json ADDED
The diff for this file is too large to render. See raw diff
 
Planning/data/finetune/mag/llama_ft.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
Planning/data/finetune/prime/1000.json ADDED
@@ -0,0 +1,55 @@
+ [
+     {
+         "query": "Is the lens-specific intermediate filament-like protein, filensin, which is encoded by a gene expressed in the lens of camera-type eyes but not found in nasal cavity epithelial cells, involved in providing structural support to the ocular lens?",
+         "answer": {
+             "Triplets": [
+                 [
+                     "anatomy",
+                     "expression present",
+                     "gene/protein"
+                 ],
+                 [
+                     "anatomy",
+                     "expression absent",
+                     "gene/protein"
+                 ]
+             ],
+             "Restriction": {
+                 "anatomy": [
+                     "lens of camera-type eyes",
+                     "nasal cavity epithelial cells"
+                 ],
+                 "gene/protein": [
+                     "filensin"
+                 ],
+                 "cellular_component": [
+                     "ocular lens"
+                 ]
+             },
+             "Target": "gene/protein"
+         }
+     },
+     {
+         "query": "Which anatomical structures lack the expression of genes or proteins involved in the interaction with the fucose metabolism pathway?",
+         "answer": {
+             "Triplets": [
+                 [
+                     "anatomy",
+                     "expression absent",
+                     "gene/protein"
+                 ],
+                 [
+                     "gene/protein",
+                     "interacts with",
+                     "pathway"
+                 ]
+             ],
+             "Restriction": {
+                 "pathway": [
+                     "fucose metabolism"
+                 ]
+             },
+             "Target": "anatomy"
+         }
+     }
+ ]
Planning/data/finetune/prime/llama_ft.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
Planning/data/get_train_data/__pycache__/prompts.cpython-311.pyc ADDED
Binary file (13.1 kB).
 
Planning/data/get_train_data/get_llm_data.py ADDED
@@ -0,0 +1,237 @@
+ """
+ Desc: This file is used to get the training data from the LLM.
+
+ """
+ import sys
+ from pathlib import Path
+
+ # Get the absolute path of the current script
+ current_file = Path(__file__).resolve()
+ project_root = current_file.parents[3]
+
+ # Add the project root to the system path
+ sys.path.append(str(project_root))
+
+ from stark_qa import load_qa
+
+ import argparse
+ import os
+ from openai import AzureOpenAI
+ import json
+ import openai
+ from prompts import prompts
+
+
+ """
+ Token counts and cost estimates:
+
+ MAG:
+ sys_content: 478/query
+ output: 45/query
+ input: 25/query
+ 1000 queries
+
+ total price:
+ 1. o1: $13.29
+ 2. o3mini: $0.97
+ 3. deepseek-chat: $0.24
+ 4. deepseek-reasoner: $0.49
+
+ Amazon:
+ sys_content: 478/query
+
+ """
+
+ # get the prompt for different datasets
+ def get_sys_content(dataset_name):
+     """
+     input:
+         dataset_name: the name of the dataset
+     output:
+         sys_content: the sys_content for the dataset
+     """
+     sys_content = prompts(dataset_name)
+
+     return sys_content
+
+ # get the response from the llm
+ def get_response(sys_content, user_content):
+
+     messages = [{"role": "system", "content": sys_content},
+                 {"role": "user", "content": user_content}
+                 ]
+
+     chat_completion = client.chat.completions.create(
+         messages=messages,
+         model=parameters['azure']['model'],  # parameters['azure']['model'] or parameters['openai']['model']
+         # temperature=0,
+         seed=576879897,
+     )
+     response = chat_completion.choices[0].message.content
+
+     return response
+
+ # save the outputs to a json file
+ def save_json(data, dataset_name):
+     """
+     input:
+         data: the data to be saved
+         dataset_name: the name of the dataset
+     """
+
+     file_dir = f"/home/yongjia/dgl/Yongjia/MOE/Reasoner/data/finetune/{dataset_name}"
+     os.makedirs(file_dir, exist_ok=True)
+     file_path = f"{file_dir}/1000_{parameters['azure']['model']}.json"
+
+     with open(file_path, 'w') as f:
+         json.dump(data, f, indent=4)
+     print(f"Saved to {file_path}")
+
+ # get the reasoning graphs for a dataset
+ def get_rg(dataset_name):
+     """
+     input:
+         dataset_name: the name of the dataset
+     output:
+         rg: the reasoning graph for the dataset
+     """
+
+     # get the prompt for the dataset
+     sys_content = get_sys_content(dataset_name)
+
+     # get the qa dataset
+     qa = load_qa(dataset_name)
+     train_qa = qa.get_subset('train')
+
+     # we sample 1000 queries from the training set (iterate over up to 1500 to allow for parsing failures)
+     pair_list = []
+     failure_count = 0
+     for i in range(1500):
+         query, q_id, ans_ids, _ = train_qa[i]
+
+         # call the llm to get the reasoning graph
+         response = get_response(sys_content, query)
+         print(response)
+
+         # process the response
+         if dataset_name == 'prime':
+             output = {
+                 "Triplets": [],
+                 "Restriction": [],
+                 "Target": ""
+             }
+
+             try:
+                 response = response.split('\n')
+                 triplets_raw = response[0].replace('Triplets:', '').strip()
+                 triplets = json.loads(triplets_raw)
+                 output['Triplets'] = triplets
+
+                 restriction_raw = response[1].replace('Restriction:', '').strip()
+                 restriction = json.loads(restriction_raw)
+                 output['Restriction'] = restriction
+
+                 target = response[2].replace('Target:', '').strip()
+                 output['Target'] = target
+             except Exception:
+                 failure_count += 1
+                 continue
+
+         elif dataset_name in ('mag', 'amazon'):
+             output = {
+                 "Metapath": "",
+                 "Restriction": [],
+             }
+
+             try:
+                 response = response.split('\n')
+                 metapath = response[0].replace('Metapath:', '').strip()
+                 output['Metapath'] = metapath
+
+                 restriction_raw = response[1].replace('Restriction:', '').strip()
+                 restriction = json.loads(restriction_raw)
+                 output['Restriction'] = restriction
+             except Exception:
+                 failure_count += 1
+                 continue
+
+         else:
+             raise ValueError('The dataset is not supported')
+
+         pair = {'query': query, 'answer': output}
+         pair_list.append(pair)
+
+         if len(pair_list) == 1000:
+             break
+
+     # save the output to a json file
+     save_json(pair_list, dataset_name)
+     print(f"Failure count: {failure_count}")
+
+
+ if __name__ == '__main__':
+     # Argument parser setup
+     parser = argparse.ArgumentParser(description="Load LLM parameters and initialize API clients.")
+
+     # Dataset name
+     parser.add_argument("--dataset_name", type=str, required=True,
+                         choices=["mag", "amazon", "prime"],
+                         help="Specify the dataset to use.")
+
+     # Model selection
+     parser.add_argument("--model", type=str, required=True,
+                         choices=["gpt-4o-mini-20240718", "gpt-4o-2024-05-13",
+                                  "deepseek-reasoner", "gpt-o1-2024-12-17",
+                                  "o3-mini-2025-01-31"],
+                         help="Specify the model to use.")
+
+     # Azure API parameters
+     parser.add_argument("--azure_api_key", type=str, default=None, help="Azure API Key")
+     parser.add_argument("--azure_endpoint", type=str, default=None, help="Azure API Endpoint")
+     parser.add_argument("--azure_api_version", type=str, default=None, help="Azure API Version")
+
+     # OpenAI API parameters
+     parser.add_argument("--openai_api_key", type=str, default=None, help="OpenAI API Key")
+     parser.add_argument("--openai_endpoint", type=str, default=None, help="OpenAI API Endpoint")
+
+     args = parser.parse_args()
+
+     # Initialize the parameters dictionary; the model name is stored under both providers
+     # because get_response and save_json look it up via parameters['azure']['model']
+     parameters = {
+         "azure": {
+             "api_key": args.azure_api_key,
+             "azure_endpoint": args.azure_endpoint,
+             "api_version": args.azure_api_version,
+             "model": args.model,
+         },
+         "openai": {
+             "api_key": args.openai_api_key,
+             "endpoint": args.openai_endpoint,
+             "model": args.model,
+         }
+     }
+
+     # Determine which API client to use
+     if parameters["openai"]["api_key"]:
+         client = openai.OpenAI(
+             base_url=parameters["openai"]["endpoint"],
+             api_key=parameters["openai"]["api_key"],
+         )
+     else:
+         client = AzureOpenAI(
+             azure_endpoint=parameters["azure"]["azure_endpoint"],
+             api_key=parameters["azure"]["api_key"],
+             api_version=parameters["azure"]["api_version"],
+         )
+
+     get_rg(args.dataset_name)
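For reference, the parsing above assumes the model answers in the line-oriented format requested by the prompt; a hedged sketch of what a well-formed 'prime' response looks like and how it is split (hypothetical response text):

    import json

    response = ('Triplets: [["disease", "contraindication", "drug"]]\n'
                'Restriction: {"disease": ["focal hand dystonia"]}\n'
                'Target: drug')
    lines = response.split('\n')
    triplets = json.loads(lines[0].replace('Triplets:', '').strip())        # [["disease", "contraindication", "drug"]]
    restriction = json.loads(lines[1].replace('Restriction:', '').strip())  # {"disease": ["focal hand dystonia"]}
    target = lines[2].replace('Target:', '').strip()                        # "drug"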
Planning/data/get_train_data/post_process_data.py ADDED
@@ -0,0 +1,58 @@
+ """
+ Description: This script is used to post-process the data after getting it from the LLM.
+
+ input: data from the LLM
+ output: data for llama finetuning
+
+ """
+ import json
+ import os
+ import copy
+
+ def save_data(data, dataset_name, model_name):
+     file_dir = f"../data/finetune/{dataset_name}"
+     os.makedirs(file_dir, exist_ok=True)
+     file_path = os.path.join(file_dir, f"llama_ft_{model_name}.jsonl")
+
+     with open(file_path, "w") as f:
+         for d in data:
+             f.write(json.dumps(d) + "\n")
+
+     print(f"Data saved to {file_path}")
+
+ def process(sample, dataset_name, model_name):
+     """
+     input: sample from the LLM
+     output: sample for llama finetuning
+     """
+
+     output_format = {"conversations": [{"role": "user", "content": ""}, {"role": "assistant", "content": ""}]}
+     ft_list = []
+     for i in range(len(sample)):
+         output = copy.deepcopy(output_format)
+         output["conversations"][0]["content"] = sample[i]["query"]
+         output["conversations"][1]["content"] = str(sample[i]["answer"])
+         ft_list.append(output)
+
+     # save data
+     save_data(ft_list, dataset_name, model_name)
+
+
+ # ***** Main *****
+ if __name__ == "__main__":
+     # read data
+     dataset_name_list = ["mag"]
+     model_names = ["gpt-4o-mini-20240718", "o3-mini-2025-01-31", "gpt-o1-2024-12-17", "gpt-4o-2024-05-13", "gpt-4o-mini-20240718", "gpt35-1106"]
+
+     for dataset_name in dataset_name_list:
+         for model_name in model_names:
+             relative_path = f"finetune/{dataset_name}/1000.json"  # f"finetune/{dataset_name}/1000_{dataset_name}.json"
+             current_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))  # Parent directory of the current script's directory
+             file_path = os.path.join(current_dir, relative_path)
+             with open(file_path, "r") as f:
+                 sample = json.load(f)
+
+             # process data
+             process(sample, dataset_name, model_name)
+             print(f"Processing {model_name} for {dataset_name} is done.")
+             break  # only the first model_name is processed; the input file does not depend on model_name
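Each resulting JSONL line has the shape below (abridged, hypothetical values). Note that the assistant content is the str() of the answer dict, so it uses Python single quotes; this is why the training script later parses it with ast.literal_eval rather than json.loads:

    {"conversations": [{"role": "user", "content": "Publications by ... ?"}, {"role": "assistant", "content": "{'Metapath': 'institution -> author -> paper', 'Restriction': {...}}"}]}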
Planning/data/get_train_data/prompts.py ADDED
@@ -0,0 +1,274 @@
+ def prompts(dataset_name):
+     """
+     input:
+         dataset_name: the name of the dataset
+     output:
+         sys_content: the prompt for the dataset
+     """
+     if dataset_name == "amazon":
+         sys_prompt = """
+         You are a reasoning graph finder agent. Your role is to:
+         1. Identify the underlying **meta-path** from a given question, which consists of the **entity types** at each reasoning step.
+         2. Extract the **content restriction** for each **entity type** based on the question. If there is no restriction for an entity type, leave its value empty.
+
+         You will be provided with a predefined **Entity Type List**. Only use the entity types from this list when constructing the meta-path and restrictions. Your response must be concise and strictly adhere to the specified **Output Format**.
+         """
+
+         # Define the entity type list
+         entity_type_list = """
+         Entity Type List:
+         - **brand**
+         - **category**
+         - **product**
+         """
+
+         demonstrations = """
+         Here are some examples:
+         **Question 1:**
+         What are some 5x4 inch team sports decals by Football Fanatics that are easy to apply on exterior surfaces?
+
+         **Answer 1:**
+         Metapath: brand -> product
+         Restriction: {"brand": ["Football Fanatics"], "product": ["5x4 inch team sports decals by Football Fanatics that are easy to apply on exterior surfaces"]}
+
+         **Question 2:**
+         Looking for men's insulated pants suitable for heavy rain, up to 10k water resistance, and with reinforced cuffs specifically for added protection in parking lots. Any suggestions?
+
+         **Answer 2:**
+         Metapath: category -> product
+         Restriction: {"category": ["pants"], "product": ["men's insulated pants suitable for heavy rain", "reinforced cuffs specifically for added protection in parking lots"]}
+
+         **Question 3:**
+         Can you recommend a dive mask that would work well with the Dive Mask Freediving Mask Spearfishing Mask Low Volume Mini Mask? I need the two masks to be compatible for my diving activities.
+
+         **Answer 3:**
+         Metapath: product -> product
+         Restriction: {"product": ["work well with the Dive Mask Freediving Mask Spearfishing Mask Low Volume Mini Mask", "compatible for diving activities", "a dive mask"]}
+
+         **Question 4:**
+         What are some UV protective women's golf jackets from PUMA? I'm cautious about skin protection, but I'm a big fan of PUMA.
+
+         **Answer 4:**
+         Metapath: brand -> product <- category
+         Restriction: {"brand": ["PUMA"], "category": ["golf jackets"], "product": ["UV protective women's golf jackets from PUMA", "skin protection"]}
+         """
+
+         output_format = """
+         **Output Format**
+         Metapath: "",
+         Restriction: {}
+         """
+
+         sys_content = sys_prompt + entity_type_list + output_format + demonstrations
+
+     elif dataset_name == "mag":
+
+         sys_prompt = """
+         You are a reasoning graph finder agent. Your role is to:
+         1. Identify the underlying **meta-path** from a given question, which consists of the **entity types** at each reasoning step.
+         2. Extract the **content restriction** for each **entity type** based on the question. If there is no restriction for an entity type, leave its value empty.
+
+         You will be provided with a predefined **Entity Type List**. Only use the entity types from this list when constructing the meta-path and restrictions. Your response must be concise and strictly adhere to the specified **Output Format**.
+         """
+
+         entity_type_list = """
+         Entity Type List:
+         - **paper**
+         - **author**
+         - **institution**
+         - **field_of_study**
+         """
+
+         output_format = """
+         **Output Format**
+         Metapath: "",
+         Restriction: {}
+         """
+
+         demonstrations = """
+         Here are some examples:
+         **Question 1:**
+         Show me research articles on the association of quasi-periodic oscillations (QPOs) with noise in celestial bodies within the context of Bicoherence.
+
+         **Answer 1:**
+         Metapath: field_of_study -> paper
+         Restriction: {"field_of_study": ["oscillations", "physics"], "paper": ["association of quasi-periodic oscillations (QPOs) with noise in celestial bodies within the context of Bicoherence"]}
+
+         **Question 2:**
+         What research on water absorption in different frequency ranges have been referenced or deemed significant in the paper entitled 'High-resolution terahertz atmospheric water vapor continuum measurements'?
+
+         **Answer 2:**
+         Metapath: paper -> paper
+         Restriction: {"paper": ["water absorption in different frequency ranges", "High-resolution terahertz atmospheric water vapor continuum measurements"]}
+
+         **Question 3:**
+         Show me publications by A.J. Turvey on the topic of supersymmetry particle searches.
+
+         **Answer 3:**
+         Metapath: author -> paper
+         Restriction: {"author": ["A.J. Turvey"], "paper": ["supersymmetry particle searches"]}
+
+         **Question 4:**
+         Looking for papers co-authored by someone involved in "Strain transferring mechanism analysis of the substrate-bonded FBG sensor". The papers should be in the similar field which is fiber optic strain sensors, and further discuss their development and application.
+
+         **Answer 4:**
+         Metapath: paper -> author -> paper
+         Restriction: {"paper": ["Strain transferring mechanism analysis of the substrate-bonded FBG sensor", "in the similar field which is fiber optic strain sensors, and further discuss their development and application"]}
+
+         **Question 5:**
+         Can you find any publications by the authors of "Tradeoffs in the Realization of Electrically Pumped Vertical External Cavity Surface Emitting Lasers," that delve into the topic of hybrid quantum well/quantum dot structures in the context of lasers?
+
+         **Answer 5:**
+         Metapath: paper -> author -> paper <- field_of_study
+         Restriction: {"paper": ["Tradeoffs in the Realization of Electrically Pumped Vertical External Cavity Surface Emitting Lasers", "hybrid quantum well/quantum dot structures in the context of lasers"], "field_of_study": ["lasers", "hybrid quantum well/quantum dot", "physics", "Emission spectrum"]}
+
+         **Question 6:**
+         Which publications from Altair Engineering authors focus on improving directional sensitivity across a wide range of frequencies?
+
+         **Answer 6:**
+         Metapath: institution -> author -> paper
+         Restriction: {"institution": ["Altair Engineering"], "paper": ["improving directional sensitivity across a wide range of frequencies"]}
+
+         **Question 7:**
+         Publications by Carlisle Companies authors on satellite instrumentation and space performance
+
+         **Answer 7:**
+         Metapath: institution -> author -> paper <- field_of_study
+         Restriction: {"institution": ["Carlisle Companies"], "paper": ["satellite instrumentation and space performance"], "field_of_study": ["satellite", "space performance"]}
+         """
+
+         sys_content = sys_prompt + entity_type_list + output_format + demonstrations
+
+     elif dataset_name == "prime":
+
+         sys_prompt = """
+         You are a triplets extractor. Given a list of triplets and a query, please
+         1. extract the triplets contained in the query and give a list to me.
+         2. make a restriction list which contains the description of the entity in the query.
+         3. tell me which entity the query is asking for.
+         Your response must be concise and strictly adhere to the specified **Output Format**.
+         """
+
+         triplets = """
+         Triplets list:
+         [('anatomy', 'expression absent', 'gene/protein'),
+         ('anatomy', 'expression present', 'gene/protein'),
+         ('anatomy', 'parent-child', 'anatomy'),
+         ('biological_process', 'interacts with', 'exposure'),
+         ('biological_process', 'interacts with', 'gene/protein'),
+         ('biological_process', 'parent-child', 'biological_process'),
+         ('cellular_component', 'interacts with', 'exposure'),
+         ('cellular_component', 'interacts with', 'gene/protein'),
+         ('cellular_component', 'parent-child', 'cellular_component'),
+         ('disease', 'associated with', 'gene/protein'),
+         ('disease', 'contraindication', 'drug'),
+         ('disease', 'indication', 'drug'),
+         ('disease', 'linked to', 'exposure'),
+         ('disease', 'off-label use', 'drug'),
+         ('disease', 'parent-child', 'disease'),
+         ('disease', 'phenotype absent', 'effect/phenotype'),
+         ('disease', 'phenotype present', 'effect/phenotype'),
+         ('drug', 'carrier', 'gene/protein'),
+         ('drug', 'contraindication', 'disease'),
+         ('drug', 'enzyme', 'gene/protein'),
+         ('drug', 'indication', 'disease'),
+         ('drug', 'off-label use', 'disease'),
+         ('drug', 'side effect', 'effect/phenotype'),
+         ('drug', 'synergistic interaction', 'drug'),
+         ('drug', 'target', 'gene/protein'),
+         ('drug', 'transporter', 'gene/protein'),
+         ('effect/phenotype', 'associated with', 'gene/protein'),
+         ('effect/phenotype', 'parent-child', 'effect/phenotype'),
+         ('effect/phenotype', 'phenotype absent', 'disease'),
+         ('effect/phenotype', 'phenotype present', 'disease'),
+         ('effect/phenotype', 'side effect', 'drug'),
+         ('exposure', 'interacts with', 'biological_process'),
+         ('exposure', 'interacts with', 'cellular_component'),
+         ('exposure', 'interacts with', 'gene/protein'),
+         ('exposure', 'interacts with', 'molecular_function'),
+         ('exposure', 'linked to', 'disease'),
+         ('exposure', 'parent-child', 'exposure'),
+         ('gene/protein', 'associated with', 'disease'),
+         ('gene/protein', 'associated with', 'effect/phenotype'),
+         ('gene/protein', 'carrier', 'drug'),
+         ('gene/protein', 'enzyme', 'drug'),
+         ('gene/protein', 'expression absent', 'anatomy'),
+         ('gene/protein', 'expression present', 'anatomy'),
+         ('gene/protein', 'interacts with', 'biological_process'),
+         ('gene/protein', 'interacts with', 'cellular_component'),
+         ('gene/protein', 'interacts with', 'exposure'),
+         ('gene/protein', 'interacts with', 'molecular_function'),
+         ('gene/protein', 'interacts with', 'pathway'),
+         ('gene/protein', 'ppi', 'gene/protein'),
+         ('gene/protein', 'target', 'drug'),
+         ('gene/protein', 'transporter', 'drug'),
+         ('molecular_function', 'interacts with', 'exposure'),
+         ('molecular_function', 'interacts with', 'gene/protein'),
+         ('molecular_function', 'parent-child', 'molecular_function'),
+         ('pathway', 'interacts with', 'gene/protein'),
+         ('pathway', 'parent-child', 'pathway')]
+         """
+
+         output_format = """
+         **Output Format**
+         Triplets: []
+         Restriction: {}
+         Target: ""
+         """
+
+         demonstrations = """
+         Here are some examples:
+         **Question 1:**
+         Search for conditions that lack any associated treatment medications and have a connection to the formation of Onion bulb structures.
+
+         **Answer 1:**
+         Triplets: [["effect/phenotype", "phenotype present", "disease"], ["drug", "off-label use", "disease"]]
+         Restriction: {"effect/phenotype": ["formation of Onion bulb structures"]}
+         Target: disease
+
+         **Question 2:**
+         What drug should be avoided for focal hand dystonia and also targets the ORM2 gene/protein?
+
+         **Answer 2:**
+         Triplets: [["disease", "contraindication", "drug"], ["drug", "target", "gene/protein"]]
+         Restriction: {"disease": ["focal hand dystonia"], "gene/protein": ["ORM2"]}
+         Target: drug
+
+         **Question 3:**
+         Could my hip muscle weakness be a sign of the mitochondrial disease my mother has, or another related condition?
+
+         **Answer 3:**
+         Triplets: [["disease", "parent-child", "disease"], ["effect/phenotype", "phenotype present", "disease"]]
+         Restriction: {"anatomy": ["hip muscle"], "disease": ["mitochondrial disease my mother has", "another related condition"], "effect/phenotype": ["weakness"]}
+         Target: disease
+
+         **Question 4:**
+         Can you find the genes and proteins involved with the activity of dolichyl-phosphate-mannose-dependent alpha-1,6-mannosyltransferase?
+
+         **Answer 4:**
+         Triplets: [["molecular_function", "interacts with", "gene/protein"]]
+         Restriction: {"molecular_function": ["dolichyl-phosphate-mannose-dependent alpha-1,6-mannosyltransferase"]}
+         Target: gene/protein
+         """
+
+         sys_content = sys_prompt + triplets + output_format + demonstrations
+
+     else:
+         raise ValueError(f"Unsupported dataset: {dataset_name}")
+
+     return sys_content
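Usage is one call per dataset; the function simply concatenates the prompt segments:

    sys_content = prompts("prime")  # sys_prompt + triplets + output_format + demonstrations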
Planning/data/train_eval.py ADDED
@@ -0,0 +1,223 @@
+ """
+ Description: This file is used to train and evaluate the llama model.
+
+ """
+ import os
+ from unsloth import FastLanguageModel
+ from trl import SFTTrainer
+ from transformers import TrainingArguments, DataCollatorForSeq2Seq
+ from unsloth import is_bfloat16_supported
+ from datasets import load_dataset
+ from unsloth.chat_templates import train_on_responses_only
+ from unsloth.chat_templates import get_chat_template
+ from sklearn.model_selection import train_test_split
+ from transformers import TextStreamer
+ import ast
+ import contractions
+ import re
+ from utils import remove_inner_single_quotes
+
+ # *****load model and tokenizer*****
+ max_seq_length = 2048
+ dtype = None  # None for auto detection. Float16 for Tesla T4, V100; Bfloat16 for Ampere+
+ load_in_4bit = True  # Use 4bit quantization to reduce memory usage.
+
+ model, tokenizer = FastLanguageModel.from_pretrained(
+     model_name = "unsloth/Llama-3.2-3B-Instruct",
+     max_seq_length = max_seq_length,  # the maximum input length the model can accept
+     dtype = dtype,
+     load_in_4bit = load_in_4bit,
+ )
+
+ # add LoRA adapters so we only need to update 1 to 10% of all parameters
+ # Reference: https://colab.research.google.com/drive/1Ys44kVvmeZtnICzWz0xgpRnrIOjZAuxp?usp=sharing#scrollTo=2eSvM9zX_2d3
+ model = FastLanguageModel.get_peft_model(
+     model,
+     r = 16,  # Choose any number > 0. Suggested: 8, 16, 32, 64, 128
+     target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
+                       "gate_proj", "up_proj", "down_proj",],
+     lora_alpha = 16,
+     lora_dropout = 0,  # Supports any, but = 0 is optimized
+     bias = "none",  # Supports any, but = "none" is optimized
+     # "unsloth" uses 30% less VRAM and fits 2x larger batch sizes
+     use_gradient_checkpointing = "unsloth",  # True or "unsloth" for very long context
+     random_state = 3407,
+     use_rslora = False,  # rank-stabilized LoRA is supported
+     loftq_config = None,  # and LoftQ
+ )
+
+ tokenizer = get_chat_template(
+     tokenizer,
+     chat_template = "llama-3.1",
+ )
+
+ def formatting_prompts_func(data):
+     convos = data['conversations']
+     texts = [tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False) for convo in convos]
+     return { "text" : texts, }
+
+ def train(train_dataset, val_dataset):
+     # format the dataset
+     texts_train = train_dataset.map(formatting_prompts_func, batched=True)
+     print(texts_train)
+     texts_val = val_dataset.map(formatting_prompts_func, batched=True)
+     print(texts_val)
+
+     # load trainer
+     trainer = SFTTrainer(
+         model = model,
+         tokenizer = tokenizer,
+         train_dataset = texts_train,
+         eval_dataset = texts_val,
+         dataset_text_field = "text",
+         max_seq_length = max_seq_length,
+         data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer),  # batches and pads data during training and evaluation
+         dataset_num_proc = 2,  # number of processes to use for data preprocessing
+         packing = False,  # Can make training 5x faster for short sequences.
+         args = TrainingArguments(
+             per_device_train_batch_size = 4,
+             gradient_accumulation_steps = 8,
+             warmup_steps = 5,
+             num_train_epochs = 100,  # Set this for 1 full training run. # TODO: increase the value
+             max_steps = 1000,  # TODO: increase the value
+             learning_rate = 2e-4,
+             fp16 = not is_bfloat16_supported(),
+             bf16 = is_bfloat16_supported(),
+             logging_steps = 1,
+             optim = "adamw_8bit",
+             weight_decay = 0.01,
+             lr_scheduler_type = "linear",
+             seed = 3407,
+             output_dir = "outputs",
+             do_eval = True,
+             report_to = 'wandb',
+             evaluation_strategy = "epoch",  # Evaluate on texts_val at the end of each epoch; setting an evaluation strategy without an eval dataset raises an error.
+             save_strategy = "epoch",  # Save checkpoints at the end of each epoch.
+             load_best_model_at_end = True,  # Load the best checkpoint from disk into memory; otherwise the model in memory is the one from the last epoch.
+             metric_for_best_model = "loss",  # Evaluation metric on the evaluation set; must be an option supported by TrainingArguments.
+             greater_is_better = False,  # For loss, lower is better.
+             save_total_limit = 1  # How many checkpoints will be kept.
+         ),
+     )
+
+     # train only on outputs and ignore the loss on the user's inputs
+     trainer = train_on_responses_only(
+         trainer,
+         instruction_part = "<|start_header_id|>user<|end_header_id|>\n\n",
+         response_part = "<|start_header_id|>assistant<|end_header_id|>\n\n",
+     )
+
+     trainer_stats = trainer.train()
+
+     # save the model
+     checkpoint_dir = f"./checkpoints/{dataset_name}"
+     os.makedirs(checkpoint_dir, exist_ok=True)
+     checkpoint_path = os.path.join(checkpoint_dir, "lora_model")
+     model.save_pretrained(checkpoint_path)  # Local saving
+     tokenizer.save_pretrained(checkpoint_path)
+
+     print("Training completed.")
+
+     return checkpoint_path
+
+ def evaluate(test_dataset, checkpoint_path):
+
+     # load model
+     model, tokenizer = FastLanguageModel.from_pretrained(
+         model_name = checkpoint_path,
+         max_seq_length = max_seq_length,  # the maximum input length the model can accept
+         dtype = dtype,
+         load_in_4bit = load_in_4bit,
+     )
+
+     FastLanguageModel.for_inference(model)
+
+     # evaluate
+     acc = 0
+     for idx, dp in enumerate(test_dataset):  # TODO: batch
+         message = dp['conversations'][0]
+         label = dp['conversations'][1]
+         assert label['role'] == "assistant"
+
+         # format the input
+         inputs = tokenizer.apply_chat_template([message], tokenize = True, add_generation_prompt = True, return_tensors = 'pt').to('cuda')
+
+         text_streamer = TextStreamer(tokenizer, skip_prompt = True)
+         outputs = model.generate(input_ids = inputs, streamer = text_streamer, max_new_tokens = 128,
+                                  use_cache = True, temperature = 1.5, min_p = 0.1)
+         outputs = tokenizer.batch_decode(outputs)
+         parts = outputs[0].split("<|start_header_id|>assistant<|end_header_id|>\n\n")
+         results = parts[1].replace("<|eot_id|>", "")
+         results = contractions.fix(results)
+         try:
+             results = ast.literal_eval(results)
+         except (ValueError, SyntaxError):
+             try:
+                 results = re.sub(r"\['(.*?)'", remove_inner_single_quotes, results)
+                 results = ast.literal_eval(results)
+             except (ValueError, SyntaxError):
+                 results = {
+                     "Metapath": "",
+                     "Restriction": {},
+                 }
+
+         pred_metapath = results['Metapath']
+         ground_metapath = ast.literal_eval(label['content'])['Metapath']
+         if pred_metapath == ground_metapath:
+             acc += 1
+         print(f"Prediction: {pred_metapath}")
+         print(f"Ground truth: {ground_metapath}")
+
+     print(f"Accuracy: {acc / len(test_dataset)}")
+
+
+ def main(dataset_name, model_name):
+     # *****load dataset*****
+     data_dir = "./data/finetune"
+     data_path = os.path.join(data_dir, f"{dataset_name}/llama_ft.jsonl")
+     # data_path = os.path.join(data_dir, f"{dataset_name}/llama_ft_{model_name}.jsonl")
+     dataset = load_dataset("json", data_files=data_path)
+     dataset = dataset['train']
+
+     # Add an index column to keep track of original indices
+     dataset = dataset.add_column("index", list(range(len(dataset))))
+
+     # Perform the train/val/test split (80/10/10)
+     train_test = dataset.train_test_split(test_size=0.2, seed=42)
+     val_test = train_test['test'].train_test_split(test_size=0.5, seed=42)
+
+     # Extract the datasets
+     train_dataset = train_test['train']
+     val_dataset = val_test['train']
+     test_dataset = val_test['test']
+
+     # Remove the index column if not needed
+     train_dataset = train_dataset.remove_columns(['index'])
+     val_dataset = val_dataset.remove_columns(['index'])
+     test_dataset = test_dataset.remove_columns(['index'])
+
+     # *****train and evaluate*****
+     checkpoint_path = train(train_dataset, val_dataset)
+
+     checkpoint_path = f"./checkpoints/{dataset_name}/lora_model"
+     # checkpoint_path = f"./checkpoints/{dataset_name}/lora_model_{model_name}"
+     evaluate(test_dataset, checkpoint_path)
+
+ if __name__ == "__main__":
+     dataset_name_list = ['mag', 'amazon', 'prime']
+     model_names = ["4o"]
+     for dataset_name in dataset_name_list:
+         for model_name in model_names:
+             main(dataset_name, model_name)
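The two-stage split in main() yields an 80/10/10 train/val/test partition; a minimal sketch of the arithmetic with the same seeds (toy dataset):

    from datasets import Dataset

    ds = Dataset.from_dict({"x": list(range(10))})
    tt = ds.train_test_split(test_size=0.2, seed=42)            # 8 train / 2 held out
    vt = tt["test"].train_test_split(test_size=0.5, seed=42)    # 1 val / 1 test
    print(len(tt["train"]), len(vt["train"]), len(vt["test"]))  # 8 1 1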
Planning/model.py ADDED
@@ -0,0 +1,89 @@
+ import sys
+ import os
+ import torch.nn as nn
+ import ast
+ from unsloth import FastLanguageModel
+ from transformers import TextStreamer
+ import contractions
+ import re
+ from Planning.utils import remove_inner_single_quotes
+
+
+ class Planner(nn.Module):
+     def __init__(self, dataset_name):
+         super(Planner, self).__init__()
+         self.dataset_name = dataset_name
+         self.checkpoint_path = f"Planning/checkpoints/{dataset_name}/lora_model/"
+         self.max_seq_length = 2048
+         self.dtype = None
+         self.load_in_4bit = True
+
+         model, tokenizer = FastLanguageModel.from_pretrained(
+             model_name = self.checkpoint_path,
+             max_seq_length = self.max_seq_length,
+             dtype = self.dtype,
+             load_in_4bit = self.load_in_4bit
+         )
+
+         FastLanguageModel.for_inference(model)
+
+         self.model = model
+         self.tokenizer = tokenizer
+
+     def forward(self, query):
+         message = {'content': query, 'role': 'user'}
+         inputs = self.tokenizer.apply_chat_template(
+             [message],
+             tokenize = True,
+             add_generation_prompt = True,  # Must be added for generation
+             return_tensors = "pt",
+         ).to("cuda")
+
+         text_streamer = TextStreamer(self.tokenizer, skip_prompt = True)
+         outputs = self.model.generate(input_ids = inputs, streamer = text_streamer, max_new_tokens = 128,  # max_new_tokens is the maximum number of new tokens generated beyond the input
+                                       use_cache = True, temperature = 1.5, min_p = 0.1)  # min_p sets a minimum probability threshold relative to the top token, keeping generation diverse but plausible
+
+         outputs = self.tokenizer.batch_decode(outputs)
+         parts = outputs[0].split("<|start_header_id|>assistant<|end_header_id|>\n\n")
+
+         if len(parts) > 1:
+             results = parts[1].replace("<|eot_id|>", "")
+         else:
+             raise ValueError("No assistant response found in the generated output")
+
+         # ******* special processing for the prime dataset
+         if self.dataset_name == 'prime':
+             try:
+                 # Parse the string using ast.literal_eval
+                 parsed_dict = ast.literal_eval(results)
+                 return parsed_dict
+             except (SyntaxError, ValueError) as e:
+                 print(f"Error parsing the string: {e}")
+                 return {
+                     "Metapath": "",
+                     "Restriction": {}
+                 }
+
+         results = contractions.fix(results)
+
+         try:
+             results = ast.literal_eval(results)
+         except (SyntaxError, ValueError):
+             print("Failed to parse; retrying after removing inner single quotes")
+             try:
+                 results = re.sub(r"\['(.*?)'", remove_inner_single_quotes, results)  # TODO: needs optimization
+                 results = ast.literal_eval(results)
+             except (SyntaxError, ValueError):
+                 results = {
+                     "Metapath": "",
+                     "Restriction": {},
+                 }
+         rg = results
+
+         return rg
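A hedged usage sketch for the Planner (assumes a finetuned LoRA checkpoint exists at Planning/checkpoints/<dataset>/lora_model/ and a CUDA device is available; the query is illustrative):

    planner = Planner("mag")
    rg = planner("Show me publications by A.J. Turvey on the topic of supersymmetry particle searches.")
    # expected shape: {"Metapath": "author -> paper", "Restriction": {"author": ["A.J. Turvey"], ...}}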
Planning/utils.py ADDED
@@ -0,0 +1,4 @@
+ def remove_inner_single_quotes(match):
+     content = match.group(1)  # Get the string inside ['...']
+     cleaned_content = content.replace("'", "")  # Remove inner single quotes
+     return f"['{cleaned_content}"
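For context, this helper is applied via re.sub to repair model outputs that ast.literal_eval rejects because of nested quotes; a small sketch (hypothetical string):

    import re

    s = "{'product': ['men's insulated pants']}"
    fixed = re.sub(r"\['(.*?)'", remove_inner_single_quotes, s)
    # fixed == "{'product': ['mens insulated pants']}" -- now parseable by ast.literal_eval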
Reasoning/__pycache__/mor4node.cpython-311.pyc ADDED
Binary file (18.8 kB).
 
Reasoning/__pycache__/mor4node_copy.cpython-311.pyc ADDED
Binary file (15.8 kB).
 
Reasoning/__pycache__/mor4path.cpython-311.pyc ADDED
Binary file (22.7 kB).
 
Reasoning/__pycache__/ptp_mor4node.cpython-311.pyc ADDED
Binary file (21 kB).
 
Reasoning/__pycache__/utils.cpython-311.pyc ADDED
Binary file (4.99 kB).
 
Reasoning/mor4path.py ADDED
@@ -0,0 +1,435 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from pathlib import Path
3
+ # Get the absolute path of the current script
4
+ current_file = Path(__file__).resolve()
5
+ project_root = current_file.parents[2]
6
+ # Add the project root to the system path
7
+ sys.path.append(str(project_root))
8
+
9
+ from .utils import combine_dicts, parse_metapath, get_scorer, get_text_retriever, fix_length
10
+ from models.model import ModelForSTaRKQA
11
+
12
+
13
+ class MOR4Path(ModelForSTaRKQA):
14
+ def __init__(self, dataset_name, text_retriever_name, scorer_name, skb, topk=100):
15
+ super(MOR4Path, self).__init__(skb)
16
+ self.dataset_name = dataset_name
17
+ self.text_retriever = get_text_retriever(dataset_name, text_retriever_name, skb)
18
+ self.scorer = get_scorer(dataset_name, scorer_name=scorer_name, skb=skb)
19
+ # self.scorer = self.text_retriever
20
+ self.topk = topk
21
+ self.node_type_list = skb.node_type_lst()
22
+ self.edge_type_list = skb.rel_type_lst()
23
+ if self.dataset_name == "prime":
24
+ self.tp_list = skb.get_tuples()
25
+ self.target_type_list = skb.candidate_types
26
+ else:
27
+ self.tp_dict = {(tp[0], tp[-1]): tp[1] for tp in skb.get_tuples()}
28
+ self.target_type_list = ['paper' if dataset_name == 'mag' else 'product']
29
+
30
+ self.skb = skb
31
+ self.ini_k = 5 # topk for initial retrieval
32
+ self.mor_k = 10 # topk for textual retrieval in MOR
33
+ self.mor_count = 0
34
+ self.num_negs = 200
35
+
36
+
37
+
38
+ def rg2routes(self, rg):
39
+ """
40
+ input: rg: {"Metapath": "", "Restriction": {}}
41
+ output: routes: [['paper', 'author', 'paper'], ['paper', 'paper']]
42
+ """
43
+ # parse rg
44
+ metapath = rg["Metapath"]
45
+ if isinstance(rg["Metapath"], list):
46
+ routes = rg["Metapath"]
47
+ elif isinstance(rg["Metapath"], str):
48
+ routes = parse_metapath(metapath)
49
+ else:
50
+ return None
51
+
52
+ return routes
53
+
54
+ def check_valid(self, routes, rg):
55
+ # check the length of routes
56
+ if not routes:
57
+ # raise ValueError(f"Empty routes: {routes}")
58
+ return None
59
+
60
+ if len(routes) == 1 and len(routes[0]) == 1: # single node, directly do text retrieval
61
+ return None
62
+
63
+
64
+ # Step 1: Filter routes by target type
65
+ target_type_valid_routes = [
66
+ route for route in routes if route[-1] in self.target_type_list
67
+ ]
68
+ if not target_type_valid_routes:
69
+
70
+ return None
71
+
72
+ # Step 2: Filter routes by node and edge type
73
+ type_valid_routes = [
74
+ route
75
+ for route in target_type_valid_routes
76
+ if all(
77
+ node in self.node_type_list or node in self.edge_type_list
78
+ for node in route
79
+ )
80
+ ]
81
+ if not type_valid_routes:
82
+
83
+ return None
84
+
85
+ # Step 3: Check existence of relations
86
+ relation_valid_routes = []
87
+ for route in type_valid_routes:
88
+ if self.dataset_name == "prime":
89
+ if len(route) < 3:
90
+ continue
91
+ triplets = [
92
+ (route[i], route[i + 1], route[i + 2])
93
+ for i in range(0, len(route) - 2, 2)
94
+ ]
95
+
96
+ if all(tp in self.tp_list for tp in triplets) and all(len(tp) == 3 for tp in triplets): # and all length of triplets is 3
97
+ relation_valid_routes.append(route)
98
+ else:
99
+ pairs = [(route[i], route[i + 1]) for i in range(len(route) - 1)]
100
+ if all(tp in self.tp_dict.keys() for tp in pairs):
101
+ relations = [self.tp_dict[tp] for tp in pairs]
102
+
103
+ # make route with relations
104
+ new_route = []
105
+ for i in range(len(relations)):
106
+ new_route.append(pairs[i][0])
107
+ new_route.append(relations[i])
108
+ new_route.append(pairs[-1][-1])
109
+
110
+ relation_valid_routes.append(new_route)
111
+
112
+ if not relation_valid_routes:
113
+
114
+ return None
115
+
116
+ return relation_valid_routes
117
+
118
+ def get_candidates4route(self, query, q_id, route, restriction):
119
+ # initialization
120
+
121
+ ini_node_type = route[0]
122
+
123
+ try:
124
+ type_restr = "".join(restriction[ini_node_type])
125
+ except:
126
+ type_restr = ""
127
+
128
+ ini_dict = self.text_retriever.retrieve(query + " " + type_restr, q_id=q_id, topk=self.ini_k, node_type=ini_node_type)
129
+ current_node_ids = list(ini_dict.keys())
130
+
131
+
132
+ # initialize the bm_vector_dict
133
+ bm_vector_dict = {key: [value] for key, value in ini_dict.items()}
134
+
135
+ # initilization for paths
136
+ paths = {}
137
+ for c_id in current_node_ids:
138
+ paths[c_id] = [c_id]
139
+
140
+ # loop
141
+ hops = len(route)
142
+ # for hop/layer
143
+ for hop in range(0, hops-2, 2):
144
+ new_paths = {}
145
+
146
+ cur_node_type = route[hop]
147
+ next_node_type = route[hop+2]
148
+ edge_type = route[hop+1]
149
+ next_node_ids = []
150
+
151
+ new_vector_dict = {}
152
+ # for node
153
+ for node_id in current_node_ids:
154
+ neighbor_ids = self.skb.get_neighbor_nodes(idx=node_id, edge_type=edge_type)
155
+ next_node_ids.extend(neighbor_ids)
156
+
157
+ # ***** update paths and score_vector_dict *****
158
+ for neighbor_id in neighbor_ids:
159
+ if neighbor_id not in new_paths.keys(): # only add new node
160
+ new_paths[neighbor_id] = paths[node_id] + [neighbor_id]
161
+ new_vector_dict[neighbor_id] = bm_vector_dict[node_id] + [-1] # -1 for padding
162
+
163
+
164
+ bm_vector_dict = new_vector_dict
165
+
166
+
167
+
168
+ # ***** layer text retrieval *****
169
+ # if there is restriction for the next node, add text_retriever
170
+ if next_node_type in restriction.keys() and len(restriction[next_node_type]) > 0 and restriction[next_node_type] != [""]:
171
+ try:
172
+
173
+ retrieve_dict = self.text_retriever.retrieve(query+" "+"".join(restriction[next_node_type]), q_id=q_id, topk=self.mor_k, node_type=route[hop+2])
174
+
175
+ new_query = query+ " " + "".join(restriction[next_node_type])
176
+
177
+ # take union
178
+ next_node_ids.extend(list(set(retrieve_dict.keys())))
179
+
180
+ # ***** update paths and bm_vector_dict *****
181
+ for c_id in retrieve_dict.keys():
182
+ if c_id not in new_paths.keys():
183
+ new_paths[c_id] = [c_id]
184
+ bm_vector_dict[c_id] = [retrieve_dict[c_id]]
185
+
186
+ except:
187
+ pass
188
+
189
+
190
+ paths = new_paths
191
+ current_node_ids = list(set(next_node_ids))
192
+
193
+
194
+ candidates = current_node_ids
195
+
196
+ self.paths.append(paths)
197
+ self.bm_vector_dict.append(bm_vector_dict)
198
+
199
+
200
+ return candidates
201
+
202
+ def merge_candidate_pools(self, non_empty_candidates_lists):
203
+
204
+ # if only one non-empy candidates list left, return it as a set
205
+ if len(non_empty_candidates_lists) == 1:
206
+ return set(non_empty_candidates_lists[0])
207
+
208
+ # find the intersection candidates ids
209
+ result = set(non_empty_candidates_lists[0])
210
+ for lst in non_empty_candidates_lists[1:]:
211
+ result.intersection_update(lst)
212
+
213
+ # if the intersection is empty, return the union of all candidates
214
+ if len(result) == 0:
215
+ result = set()
216
+ for lst in non_empty_candidates_lists:
217
+ result.update(lst)
218
+
219
+ return list(result)
220
+
221
+ def get_mor_candidates(self, query, q_id, valid_routes, restriction):
222
+
223
+ # Step 1: Get candidates for each route
224
+ candidates_pool = []
225
+ for route in valid_routes:
226
+ if route[0] in restriction.keys() and len(restriction[route[0]]) > 0:
227
+ candidates_pool.append(self.get_candidates4route(query, q_id, route, restriction)) # topk is the candidates retrieved from textual retriever
228
+
229
+
230
+ # remove empty lists from candidates
231
+ non_empty_candidates_lists = [lst for lst in candidates_pool if lst]
232
+ if len(non_empty_candidates_lists) == 0:
233
+
234
+ return {}
235
+
236
+
237
+ # Step 2: Combine candidates from different routes, try intersection first, then union
238
+ candidates = self.merge_candidate_pools(non_empty_candidates_lists) # candidates is a list
239
+
240
+ # step 3: score the candidates, ini to -1
241
+ pred_dict = dict(zip(candidates, [-1]*len(candidates)))
242
+ # print(f"111, {pred_dict}")
243
+
244
+ return pred_dict
245
+
246
+ def check_topk(self, query, q_id, pred_dict):
247
+
248
+ missing = self.topk - len(set(pred_dict.keys()))
249
+ if missing > 0:
250
+ added_dict = self.text_retriever.retrieve(query, q_id, topk=self.topk+20, node_type=self.target_type_list) # +20 make it more safe
251
+ available_nodes = {key: value for key, value in added_dict.items() if key not in pred_dict.keys()}
252
+ sorted_available_nodes = sorted(available_nodes.items(), key=lambda x: x[1], reverse=True)
253
+ # Select only the required number of nodes to fill the missing slots
254
+ selected_nodes = dict(sorted_available_nodes[:missing])
255
+
256
+ # Update pred_dict with the selected nodes
257
+ pred_dict.update(selected_nodes)
258
+
259
+ # updata paths
260
+ for node_id in selected_nodes.keys():
261
+ self.paths[node_id] = [node_id]
262
+
263
+ # update bm_vector_dict
264
+ new_bm_vector_dict = {key: [value] for key, value in selected_nodes.items()}
265
+ self.bm_vector_dict.update(new_bm_vector_dict)
266
+
267
+ scored_dict = self.scorer.score(query, q_id=q_id, candidate_ids=list(pred_dict.keys()))
268
+
269
+ if len(scored_dict) > self.topk:
270
+ # initiliaze the new_paths
271
+ new_paths = {}
272
+
273
+ # Select the top-k nodes based on the scores
274
+ sorted_scored_dict = sorted(scored_dict.items(), key=lambda x: x[1], reverse=True)
275
+ scored_dict = dict(sorted_scored_dict[:self.topk])
276
+
277
+ # update paths
278
+ for node_id in scored_dict.keys():
279
+ new_paths[node_id] = self.paths[node_id]
280
+
281
+ self.paths = new_paths
282
+
283
+ # update bm_vector_dict
284
+ new_bm_vector_dict = {node_id: self.bm_vector_dict[node_id] for node_id in scored_dict.keys()}
285
+ self.bm_vector_dict = new_bm_vector_dict
286
+
287
+
288
+ return scored_dict
289
+
290
+
291
+ # check fixed negtopk
292
+ def check_negtopk(self, query, q_id, pred_dict, ans_ids):
293
+ # check the positive nodes
294
+ pos_ids = [node_id for node_id in ans_ids if node_id in pred_dict.keys()]
295
+ pos_dict = {key: value for key, value in pred_dict.items() if key in pos_ids}
296
+ neg_ids = pred_dict.keys() - set(pos_ids)
297
+ neg_dict = {key: value for key, value in pred_dict.items() if key in neg_ids}
298
+
299
+ # check the number of negative nodes
300
+ missing = self.num_negs - len(neg_ids)
301
+
302
+ if missing > 0:
303
+
304
+ added_dict = self.text_retriever.retrieve(query, q_id, topk=self.num_negs+200, node_type=self.target_type_list) # +20 make it more safe
305
+ available_nodes = {key: value for key, value in added_dict.items() if key not in pred_dict.keys() and key not in ans_ids}
306
+ sorted_available_nodes = sorted(available_nodes.items(), key=lambda x: x[1], reverse=True)
307
+ # Select only the required number of nodes to fill the missing slots
308
+ selected_nodes = dict(sorted_available_nodes[:missing])
309
+
310
+ # Update pred_dict with the selected nodes
311
+ neg_dict.update(selected_nodes)
312
+
313
+ # updata paths
314
+ for node_id in selected_nodes.keys():
315
+ self.paths[node_id] = [node_id]
316
+
317
+ # update bm_vector_dict
318
+ new_bm_vector_dict = {key: [value] for key, value in selected_nodes.items()}
319
+ self.bm_vector_dict.update(new_bm_vector_dict)
320
+
321
+
322
+ scored_neg_dict = self.scorer.score(query, q_id=q_id, candidate_ids=list(neg_dict.keys()))
323
+ if pos_dict:
324
+ scored_pos_dict = self.scorer.score(query, q_id=q_id, candidate_ids=list(pos_dict.keys()))
325
+ else:
326
+ scored_pos_dict = {}
327
+
328
+ if len(scored_neg_dict) > self.num_negs:
329
+ # Select the top-k nodes based on the scores
330
+ sorted_scored_neg_dict = sorted(scored_neg_dict.items(), key=lambda x: x[1], reverse=True)
331
+ scored_neg_dict = dict(sorted_scored_neg_dict[:self.num_negs])
332
+
333
+
334
+
335
+ scored_neg_dict.update(scored_pos_dict)
336
+ scored_dict = scored_neg_dict
337
+ print(len(scored_dict))
338
+
339
+ # update paths
340
+ new_paths = {}
341
+ for node_id in scored_dict.keys():
342
+ new_paths[node_id] = self.paths[node_id]
343
+ self.paths = new_paths
344
+
345
+ # update bm_vector_dict
346
+ new_bm_vector_dict = {node_id: self.bm_vector_dict[node_id] for node_id in scored_dict.keys()}
347
+ self.bm_vector_dict = new_bm_vector_dict
348
+
349
+
350
+ return scored_dict
351
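+ # Note (added, illustrative): after check_negtopk the training pool holds at most
+ # num_negs negatives plus every retrieved positive, so with num_negs=100 and
+ # 3 retrieved answers the returned dict has 103 entries.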
+
352
+
353
+ def forward(self, query, q_id, ans_ids, rg, args):
354
+
355
+ self.paths = []
356
+ self.bm_vector_dict = []
357
+ self.ada_score = {}
358
+ # ***** Structural Retrieval *****
359
+
360
+ # reasoning graph to routes
361
+ if self.dataset_name == "prime":
362
+ routes = rg["Metapath"]
363
+ else:
364
+ routes = self.rg2routes(rg)
365
+
366
+ # check valid
367
+ valid_routes = self.check_valid(routes, rg) # add check for restriction
368
+
369
+ if valid_routes is None:
370
+ # do textual retrieval
371
+ pred_dict = self.text_retriever.retrieve(query, q_id, topk=self.topk, node_type=self.target_type_list)
372
+
373
+ # update bm_vector_dict
374
+ self.bm_vector_dict = {key: [value] for key, value in pred_dict.items()}
375
+
376
+ else:
377
+ # truncate the valid_routes
378
+ if self.dataset_name == "prime":
379
+ pass
380
+ else:
381
+ valid_routes = [route[-5:] for route in valid_routes]
382
+
383
+
384
+ # do structural retrieval
385
+ restriction = rg["Restriction"]
386
+ pred_dict = self.get_mor_candidates(query, q_id, valid_routes, restriction)
387
+ self.mor_count += 1
388
+
389
+
390
+ # **** combine paths ****
391
+ if self.paths:
392
+ self.paths = combine_dicts(self.paths, pred_dict=pred_dict) # return dict
393
+ else:
394
+ self.paths = {}
395
+ for node_id in pred_dict.keys():
396
+ self.paths[node_id] = [node_id]
397
+
398
+ # ***** combine bm_vector_dict *****
399
+ if isinstance(self.bm_vector_dict, list):
400
+ self.bm_vector_dict = combine_dicts(self.bm_vector_dict, pred_dict=pred_dict)
401
+
402
+ # **** fix neg for training; fix candidates for testing ****
403
+ if args.mod == "train":
404
+ # check neg topk
405
+ pred_dict = self.check_negtopk(query, q_id, pred_dict, ans_ids)
406
+ else:
407
+ # check topk
408
+ pred_dict = self.check_topk(query, q_id, pred_dict)
409
+
410
+
411
+ # **** length padding and truncate *****
412
+ if self.dataset_name != "prime":
413
+ self.paths = fix_length(self.paths)
414
+
415
+ if len(self.paths) != len(pred_dict):
416
+ print(f"paths: {self.paths}")
417
+ print(f"pred_dict: {pred_dict}")
418
+ raise ValueError(f"Length mismatch between paths and pred_dict: {len(self.paths)}, {len(pred_dict)}")
419
+
420
+ output = {
421
+ "query": query,
422
+ "pred_dict": pred_dict,
423
+ "ans_ids": ans_ids,
424
+ 'paths': self.paths,
425
+ 'bm_vector_dict': self.bm_vector_dict,
426
+ 'rg': rg
427
+ }
428
+
429
+
430
+ return output
431
+
432
+ if __name__ == "__main__":
433
+ print(f"Test mor4path")
434
+
435
+
Reasoning/structural_retriever/__pycache__/stru4path.cpython-311.pyc ADDED
Binary file (13 kB). View file
 
Reasoning/structural_retriever/stru4path.py ADDED
@@ -0,0 +1,305 @@
1
+ """
2
+ input: rg
3
+ output (fixed 100 candidates, for path-based reranking):
4
+ {
5
+ "query": query,
6
+ "pred_dict": {node_id: score},
7
+ "ans_ids": [],
8
+ 'paths': {node_id: [node_ids_path]}
9
+ }
10
+
11
+ """
12
+ import sys
13
+ import os
14
+ sys.path.append(os.path.dirname(os.path.abspath(os.getcwd())))
15
+
16
+ from utils import combine_dicts, parse_metapath, get_scorer, get_text_retriever, fix_length
17
+ from models.model import ModelForSTaRKQA
18
+ import time
19
+
20
+
21
+
22
+ class Stru4Path(ModelForSTaRKQA):
23
+ def __init__(self, dataset_name, text_retriever_name, scorer_name, skb, topk=100):
24
+ super(Stru4Path, self).__init__(skb)
25
+ self.dataset_name = dataset_name
26
+ self.text_retriever = get_text_retriever(dataset_name, text_retriever_name, skb)
27
+ self.scorer = get_scorer(dataset_name, scorer_name=scorer_name, skb=skb)
28
+ # self.scorer = self.text_retriever
29
+ self.topk = topk
30
+ self.node_type_list = skb.node_type_lst()
31
+ self.edge_type_list = skb.rel_type_lst()
32
+ if self.dataset_name == "prime":
33
+ self.tp_list = skb.get_tuples()
34
+ self.target_type_list = skb.candidate_types
35
+ else:
36
+ self.tp_dict = {(tp[0], tp[-1]): tp[1] for tp in skb.get_tuples()}
37
+ self.target_type_list = ['paper' if dataset_name == 'mag' else 'product']
38
+
39
+ self.skb = skb
40
+ self.ini_k = 5 # topk for initial retrieval
41
+ self.stru_count = 0
42
+
43
+
44
+
45
+
46
+ def rg2routes(self, rg):
47
+ """
48
+ input: rg: {"Metapath": "", "Restriction": {}}
49
+ output: routes: [['paper', 'author', 'paper'], ['paper', 'paper']]
50
+ """
51
+ # parse rg
52
+ metapath = rg["Metapath"]
53
+ if isinstance(rg["Metapath"], list):
54
+ routes = rg["Metapath"]
55
+ elif isinstance(rg["Metapath"], str):
56
+ routes = parse_metapath(metapath)
57
+ else:
58
+ return None
59
+
60
+ return routes
61
+
62
+ def check_valid(self, routes, rg):
63
+ # check the length of routes
64
+ if not routes:
65
+ # raise ValueError(f"Empty routes: {routes}")
66
+ return None
67
+
68
+ if len(routes) == 1 and len(routes[0]) == 1: # single node, directly do text retrieval
69
+ return 1
70
+
71
+ # Step 1: Filter routes by target type
72
+ target_type_valid_routes = [
73
+ route for route in routes if route[-1] in self.target_type_list
74
+ ]
75
+ if not target_type_valid_routes:
76
+ return None
77
+
78
+ # Step 2: Filter routes by node and edge type
79
+ type_valid_routes = [
80
+ route
81
+ for route in target_type_valid_routes
82
+ if all(
83
+ node in self.node_type_list or node in self.edge_type_list
84
+ for node in route
85
+ )
86
+ ]
87
+ if not type_valid_routes:
88
+ return None
89
+
90
+ # Step 3: Check existence of relations
91
+ relation_valid_routes = []
92
+ for route in type_valid_routes:
93
+ if self.dataset_name == "prime":
94
+ triplets = [
95
+ (route[i], route[i + 1], route[i + 2])
96
+ for i in range(0, len(route) - 2, 2)
97
+ ]
98
+
99
+ if all(tp in self.tp_list for tp in triplets):
100
+ relation_valid_routes.append(route)
101
+ else:
102
+ pairs = [(route[i], route[i + 1]) for i in range(len(route) - 1)]
103
+ if all(tp in self.tp_dict.keys() for tp in pairs):
104
+ relations = [self.tp_dict[tp] for tp in pairs]
105
+
106
+ # make route with relations
107
+ new_route = []
108
+ for i in range(len(relations)):
109
+ new_route.append(pairs[i][0])
110
+ new_route.append(relations[i])
111
+ new_route.append(pairs[-1][-1])
112
+ # print(f"222, {new_route}")
113
+
114
+ relation_valid_routes.append(new_route)
115
+
116
+ if not relation_valid_routes:
117
+ return None
118
+
119
+ return relation_valid_routes
120
+
121
+ def get_candidates4route(self, query, q_id, route, restriction):
122
+ # initialization
123
+
124
+ ini_node_type = route[0]
125
+
126
+ try:
127
+ extra_restr = "".join(restriction[ini_node_type])
128
+ except (KeyError, TypeError): # no usable restriction for this node type
129
+ extra_restr = ""
130
+ ini_dict = self.text_retriever.retrieve(query + " " + extra_restr, q_id=q_id, topk=self.ini_k, node_type=ini_node_type)
131
+ current_node_ids = list(ini_dict.keys())
132
+
133
+ # initialization of paths
134
+ paths = {}
135
+ for c_id in current_node_ids:
136
+ paths[c_id] = [c_id]
137
+
138
+ # loop
139
+ hops = len(route)
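+ # (added) a valid route alternates node and edge types (illustrative names),
+ # e.g. ['paper', 'cites', 'paper'], so each hop consumes the triple
+ # (node at hop, edge at hop+1, node at hop+2) and steps by 2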
140
+ # for hop/layer
141
+ for hop in range(0, hops-2, 2):
142
+ new_paths = {}
143
+
144
+ cur_node_type = route[hop]
145
+ next_node_type = route[hop+2]
146
+ edge_type = route[hop+1]
147
+ next_node_ids = []
148
+
149
+ # for node
150
+ for node_id in current_node_ids:
151
+ neighbor_ids = self.skb.get_neighbor_nodes(idx=node_id, edge_type=edge_type)
152
+ next_node_ids.extend(neighbor_ids)
153
+
154
+ # ***** update paths *****
155
+ for neighbor_id in neighbor_ids:
156
+ new_paths[neighbor_id] = paths[node_id] + [neighbor_id]
157
+
158
+
159
+ paths = new_paths
160
+
161
+ current_node_ids = list(set(next_node_ids))
162
+
163
+ candidates = current_node_ids
164
+ self.paths.append(paths)
165
+
166
+
167
+ return candidates
168
+
169
+ def merge_candidate_pools(self, non_empty_candidates_lists):
170
+
171
+
172
+ # if only one non-empty candidates list is left, return it as a set
173
+ if len(non_empty_candidates_lists) == 1:
174
+ return set(non_empty_candidates_lists[0])
175
+ # find the intersection candidates ids
176
+ result = set(non_empty_candidates_lists[0])
177
+ for lst in non_empty_candidates_lists[1:]:
178
+ result.intersection_update(lst)
179
+
180
+ # if the intersection is empty, return the union of all candidates
181
+ if len(result) == 0:
182
+ result = set()
183
+ for lst in non_empty_candidates_lists:
184
+ result.update(lst)
185
+
186
+
187
+
188
+ return list(result)
189
+
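+ # Hedged example of the merge policy (illustrative ids only):
+ # merge_candidate_pools([[1, 2, 3], [2, 3, 4]]) -> [2, 3] (intersection)
+ # merge_candidate_pools([[1, 2], [3, 4]]) -> [1, 2, 3, 4] (union fallback; set order may vary)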
190
+ def get_mor_candidates(self, query, q_id, valid_routes, restriction):
191
+
192
+ # Step 1: Get candidates for each route
193
+ candidates_pool = []
194
+ for route in valid_routes:
195
+ if route[0] in restriction.keys() and len(restriction[route[0]]) > 0:
196
+ candidates_pool.append(self.get_candidates4route(query, q_id, route, restriction)) # topk is the candidates retrieved from textual retriever
197
+
198
+ non_empty_candidates_lists = [lst for lst in candidates_pool if lst]
199
+ if not non_empty_candidates_lists: # no candidates, return empty dict
200
+ print(f"123, {non_empty_candidates_lists}")
201
+
202
+ # raise ValueError("No candidates for any route")
203
+ return {}
204
+
205
+
206
+ # Step 2: Combine candidates from different routes, try intersection first, then union
207
+ candidates = self.merge_candidate_pools(candidates_pool) # candidates is a list
208
+ if not candidates:
209
+ return {}
210
+
211
+
212
+ # Step 3: initialize all candidate scores to -1 (actual scores are assigned by the scorer later)
213
+ pred_dict = dict(zip(candidates, [-1]*len(candidates)))
214
+ # print(f"111, {pred_dict}")
215
+
216
+ return pred_dict
217
+
218
+
219
+
220
+ def forward(self, query, q_id, ans_ids, rg):
221
+
222
+ self.paths = []
223
+ # ***** Structural Retrieval *****
224
+
225
+ # reasoning graph to routes
226
+ s_time = time.time()
227
+ routes = self.rg2routes(rg)
228
+ # print(f"444, {time.time()-s_time}")
229
+
230
+ # check valid
231
+ s_time = time.time()
232
+ valid_routes = self.check_valid(routes, rg) # add check for restriction
233
+ # print(f"555, {time.time()-s_time}")
234
+
235
+ if valid_routes is None:
236
+ # return empty dict
237
+ return {
238
+ "query": query,
239
+ "pred_dict": {},
240
+ "ans_ids": ans_ids,
241
+ 'paths': {},
242
+ 'query_pattern': rg['Metapath']
243
+ }
244
+ elif valid_routes == 1: # TODO: empty string
245
+ print(f"1234: {valid_routes}")
246
+ # do text retrieval
247
+ pred_dict = self.text_retriever.retrieve(query, q_id=q_id, topk=self.topk, node_type=f'{self.target_type_list[0]}')
248
+
249
+ else:
250
+ # do structural retrieval
251
+ # truncate the valid_routes
252
+ if self.dataset_name == "prime":
253
+ pass
254
+ else:
255
+ valid_routes = [route[-5:] for route in valid_routes]
256
+
257
+ restriction = rg["Restriction"]
258
+ pred_dict = self.get_mor_candidates(query, q_id, valid_routes, restriction)
259
+ self.stru_count += 1
260
+
261
+ # **** combine paths ****
262
+ if self.paths:
263
+ self.paths = combine_dicts(self.paths, pred_dict=pred_dict) # return dict
264
+
265
+ else:
266
+ self.paths = {}
267
+ for node_id in pred_dict.keys():
268
+ self.paths[node_id] = [node_id]
269
+
270
+ # if retrieved candidates is empty, return empty dict
271
+ if not pred_dict:
272
+ return {
273
+ "query": query,
274
+ "pred_dict": {},
275
+ "ans_ids": ans_ids,
276
+ 'paths': {},
277
+ 'query_pattern': rg['Metapath']
278
+ }
279
+
280
+ # score the candidates
281
+ pred_dict = self.scorer.score(query, q_id, list(pred_dict.keys()))
282
+
283
+ # # **** length padding and truncate *****
284
+ # self.paths = fix_length(self.paths)
285
+
286
+ if len(self.paths) != len(pred_dict):
287
+ print(f"paths: {self.paths}")
288
+ print(f"pred_dict: {pred_dict}")
289
+ raise ValueError(f"Length mismatch between paths and pred_dict: {len(self.paths)}, {len(pred_dict)}")
290
+
291
+ output = {
292
+ "query": query,
293
+ "pred_dict": pred_dict,
294
+ "ans_ids": ans_ids,
295
+ 'paths': self.paths,
296
+ 'query_pattern': rg['Metapath'],
297
+ 'rg': rg
298
+ }
299
+
300
+
301
+ return output
302
+
303
+
304
+
305
+
Reasoning/text_retrievers/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .stark_model import ModelForSTaRKQA
2
+ from .ada import Ada
3
+ from .bm25 import BM25
4
+ from .contriever import Contriever
Reasoning/text_retrievers/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (408 Bytes). View file
 
Reasoning/text_retrievers/__pycache__/ada.cpython-311.pyc ADDED
Binary file (5.39 kB). View file
 
Reasoning/text_retrievers/__pycache__/bm25.cpython-311.pyc ADDED
Binary file (7.02 kB). View file
 
Reasoning/text_retrievers/__pycache__/contriever.cpython-311.pyc ADDED
Binary file (5.69 kB). View file
 
Reasoning/text_retrievers/__pycache__/stark_model.cpython-311.pyc ADDED
Binary file (8.37 kB). View file
 
Reasoning/text_retrievers/ada.py ADDED
@@ -0,0 +1,66 @@
1
+
2
+ """
3
+ input: query, query_id, candidates_ids
4
+ output: pred_dict: {node_id: similarity}
5
+ """
6
+
7
+ import sys
8
+ from pathlib import Path
9
+ # Get the absolute path of the current script
10
+ current_file = Path(__file__).resolve()
11
+ project_root = current_file.parents[2]
12
+ # Add the project root to the system path
13
+ sys.path.append(str(project_root))
14
+ import torch
15
+ from Reasoning.text_retrievers.stark_model import ModelForSTaRKQA
16
+
17
+ class Ada(ModelForSTaRKQA):
18
+ def __init__(self, skb, dataset_name, device):
19
+ super(Ada, self).__init__(skb)
20
+ self.emb_dir = f"{project_root}/Reasoning/data/emb/{dataset_name}/"
21
+ self.query_emb_path = self.emb_dir + "text-embedding-ada-002/query/query_emb_dict.pt"
22
+ self.query_emb_dict = torch.load(self.query_emb_path)
23
+ # print(f"777, {self.query_emb_path}")
24
+
25
+ self.candidate_emb_path = self.emb_dir + "text-embedding-ada-002/doc/candidate_emb_dict.pt"
26
+ self.candidate_emb_dict = torch.load(self.candidate_emb_path)
27
+ self.device = device
28
+
29
+ assert len(self.candidate_emb_dict) == len(self.candidate_ids)
30
+
31
+ candidate_embs = [self.candidate_emb_dict[idx].view(1, -1) for idx in self.candidate_ids]
32
+ self.candidate_embs = torch.cat(candidate_embs, dim=0).to(device)
33
+
34
+ def score(self, query, q_id, candidate_ids):
35
+ """
36
+ pred_dict[node_id] = similarity (tensor)
37
+
38
+ """
39
+ # Dimension of query_emb: torch.Size([1, emb_dim])
40
+ query_emb = self.query_emb_dict[q_id].view(1, -1)
41
+ # Dimension of candidates_embs: torch.Size([num_candidates, emb_dim])
42
+ # # candidates_embs = self.candidate_embs[candidates_ids]
43
+ candi_embs = [self.candidate_emb_dict[c_id].view(1, -1) for c_id in candidate_ids]
44
+ candidates_embs = torch.cat(candi_embs, dim=0).to(self.device)
45
+ # Dimension of similarity: torch.Size([num_candidates])
46
+ similarity = torch.matmul(query_emb.to(self.device), candidates_embs.T).squeeze(dim=0).cpu()
47
+ pred_dict = {}
48
+ for i in range(len(candidate_ids)):
49
+ pred_dict[candidate_ids[i]] = similarity[i].item()
50
+
51
+ return pred_dict
52
+
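+ # Note (added): text-embedding-ada-002 vectors come back unit-normalized from the
+ # OpenAI API, so the dot products above behave as cosine similarities.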
53
+ def retrieve(self, query, q_id, topk, node_type=None):
54
+ # Dimension of query_emb: torch.Size([1, emb_dim])
55
+ query_emb = self.query_emb_dict[q_id].view(1, -1)
56
+
57
+ similarity = torch.matmul(query_emb.to(self.device), self.candidate_embs.T).cpu()
58
+ if isinstance(query, str):
59
+ pred_dict = dict(zip(self.candidate_ids, similarity.view(-1)))
60
+
61
+ sorted_pred_ids = sorted(pred_dict, key=lambda x: pred_dict[x], reverse=True)
62
+ selected_pred_ids = sorted_pred_ids[:topk]
63
+ pred_dict = {id: pred_dict[id].item() for id in selected_pred_ids}
64
+ print(f"sorted: {pred_dict}")
65
+
66
+ return pred_dict
Reasoning/text_retrievers/bm25.py ADDED
@@ -0,0 +1,108 @@
1
+ """
2
+ input: query, node_type, topk
3
+ output: pred_dict: {node_id: score}
4
+
5
+ """
6
+ import bm25s
7
+ from tqdm import tqdm
8
+ import sys
9
+ from pathlib import Path
10
+ # Get the absolute path of the current script
11
+ current_file = Path(__file__).resolve()
12
+ project_root = current_file.parents[2]
13
+ # Add the project root to the system path
14
+ sys.path.append(str(project_root))
15
+
16
+ from Reasoning.text_retrievers.stark_model import ModelForSTaRKQA
17
+
18
+ target_type = {'amazon': 'product', 'prime': 'combine', 'mag': 'paper'}
19
+
20
+ class BM25(ModelForSTaRKQA):
21
+
22
+ def __init__(self, skb, dataset_name):
23
+ super(BM25, self).__init__(skb)
24
+ self.retrievers = {}
25
+ self.text_to_ids = {}
26
+ type_names = skb.node_type_lst()
27
+ self.nodeid_to_index = {}
28
+
29
+ self.target_type = target_type[dataset_name]
30
+
31
+ if self.target_type not in type_names:
32
+ ids = skb.get_candidate_ids()
33
+
34
+
35
+ corpus = [skb.get_doc_info(id) for id in tqdm(ids, desc=f"Gathering docs for combine")]
36
+ retriever = bm25s.BM25(corpus=corpus)
37
+ retriever.index(bm25s.tokenize(corpus))
38
+ # Build hash map from text to node_id
39
+ text_to_id = {hash(text): id for text, id in zip(corpus, ids)}
40
+ # Store the retriever and text_to_id by type_name
41
+ self.retrievers[self.target_type] = retriever
42
+ self.text_to_ids[self.target_type] = text_to_id
43
+
44
+ self.nodeid_to_index[self.target_type] = {id: i for i, id in enumerate(ids)}
45
+
46
+ # Initialize retrievers and text-to-index maps for each type_name
47
+ for type_name in type_names:
48
+ ids = skb.get_node_ids_by_type(type_name)
49
+
50
+ # manually replace '&' with '_and_' because bm25s treats '&' as a special character and will not tokenize it
51
+ corpus = [skb.get_doc_info(id).replace('&', '_and_').replace('P.O.R', 'P_dot_O_dot_R') for id in tqdm(ids, desc=f"Gathering docs for {type_name}")]
52
+ # Create the BM25 model for the current type_name
53
+ retriever = bm25s.BM25(corpus=corpus)
54
+ retriever.index(bm25s.tokenize(corpus))
55
+
56
+ # Build hash map from text to index
57
+ text_to_id = {hash(text): id for text, id in zip(corpus, ids)}
58
+
59
+ # Store the retriever and text_to_id by type_name
60
+ self.retrievers[type_name] = retriever
61
+ self.text_to_ids[type_name] = text_to_id
62
+
63
+ # build map from node_id to index
64
+ self.nodeid_to_index[type_name] = {id: i for i, id in enumerate(ids)}
65
+
66
+ def score(self, query, q_id, candidate_ids):
67
+ pred_dict = {}
68
+
69
+ for c_id in candidate_ids:
70
+ type_name = self.skb.get_node_type_by_id(c_id)
71
+ score = self.retrievers[type_name].get_scores(list(bm25s.tokenize(query)[1].keys()))[self.nodeid_to_index[type_name][c_id]] # tokenize() returns (ids, vocab); the vocab keys are the query tokens
72
+ pred_dict[c_id] = score
73
+
74
+ # print(f"999, {pred_dict}")
75
+
76
+ return pred_dict
77
+
78
+ def retrieve(self, query, q_id, topk, node_type=None):
79
+
80
+ """
81
+ Forward pass to compute similarity scores for the given query.
82
+
83
+ Args:
84
+ query (str): Query string.
85
+
86
+ Returns:
87
+ pred_dict (dict): A dictionary of candidate ids and their corresponding similarity scores.
88
+ """
89
+ if '&' in query:
90
+ query = query.replace('&', '_and_')
91
+ if 'P.O.R' in query:
92
+ query = query.replace('P.O.R', 'P_dot_O_dot_R')
93
+ if isinstance(node_type, list):
94
+ if len(node_type) > 1:
95
+ node_type = 'combine'
96
+ else:
97
+ node_type = node_type[0]
98
+ results, scores = self.retrievers[node_type].retrieve(bm25s.tokenize(query), k=topk)
99
+ ids = [self.text_to_ids[node_type][hash(result.item())] for result in results[0]]
100
+ scores = scores[0].tolist()
101
+ pred_dict = dict(zip(ids, scores))
102
+ # print(f"666, {pred_dict}")
103
+
104
+ return pred_dict
105
+
106
+ if __name__ == '__main__':
107
+ print(f"Testing BM25")
108
+
Reasoning/text_retrievers/contriever.py ADDED
@@ -0,0 +1,78 @@
1
+
2
+ """
3
+ input: query, query_id, candidates_ids
4
+ output: pred_dict: {node_id: similarity}
5
+ """
6
+ import heapq
7
+ import sys
8
+ from pathlib import Path
9
+ # Get the absolute path of the current script
10
+ current_file = Path(__file__).resolve()
11
+ project_root = current_file.parents[2]
12
+ # Add the project root to the system path
13
+ sys.path.append(str(project_root))
14
+ from stark_qa.tools.api_lib.openai_emb import get_contriever, get_contriever_embeddings
15
+ import torch
16
+ from Reasoning.text_retrievers.stark_model import ModelForSTaRKQA
17
+
18
+ class Contriever(ModelForSTaRKQA):
19
+ def __init__(self, skb, dataset_name, device):
20
+ super(Contriever, self).__init__(skb)
21
+ self.emb_dir = f"{project_root}/Reasoning/data/emb/{dataset_name}/"
22
+
23
+ self.query_emb_path = self.emb_dir + "contriever/query_no_rel_no_compact/query_emb_dict.pt"
24
+ self.query_emb_dict = torch.load(self.query_emb_path)
25
+
26
+ self.candidate_emb_path = self.emb_dir + "contriever/doc_no_rel_no_compact/candidate_emb_dict.pt"
27
+ self.candidate_emb_dict = torch.load(self.candidate_emb_path)
28
+ self.device = device
29
+
30
+ assert len(self.candidate_emb_dict) == len(self.candidate_ids)
31
+
32
+ candidate_embs = [self.candidate_emb_dict[idx].view(1, -1) for idx in self.candidate_ids]
33
+ self.candidate_embs = torch.cat(candidate_embs, dim=0).to(device)
34
+
35
+ # load contriever for query embeddings
36
+ self.encoder, self.tokenizer = get_contriever(dataset_name=dataset_name)
37
+ self.encoder = self.encoder.to(device)
38
+
39
+
40
+ def score(self, query, q_id, candidate_ids):
41
+ """
42
+ pred_dict[node_id] = similarity (tensor)
43
+
44
+ """
45
+
46
+
47
+ # Dimension of query_emb: torch.Size([1, emb_dim])
48
+ query_emb = self.query_emb_dict[q_id].view(1, -1)
49
+
50
+ # Dimension of candidates_embs: torch.Size([num_candidates, emb_dim])
51
+ candi_embs = [self.candidate_emb_dict[c_id].view(1, -1) for c_id in candidate_ids]
52
+ candidates_embs = torch.cat(candi_embs, dim=0).to(self.device)
53
+ # Dimension of similarity: torch.Size([num_candidates])
54
+ similarity = torch.matmul(query_emb.to(self.device), candidates_embs.T).squeeze(dim=0).cpu()
55
+ pred_dict = {}
56
+ for i in range(len(candidate_ids)):
57
+ pred_dict[candidate_ids[i]] = similarity[i].item()
58
+
59
+ return pred_dict
60
+
61
+ def retrieve(self, query, q_id, topk, node_type=None):
62
+ # Dimension of query_emb: torch.Size([1, emb_dim])
63
+ query_emb = get_contriever_embeddings(query, encoder=self.encoder, tokenizer=self.tokenizer, device=self.device)
64
+
65
+ similarity = torch.matmul(query_emb.to(self.device), self.candidate_embs.T).cpu()
66
+
67
+ if isinstance(query, str):
68
+ pred_dict = dict(zip(self.candidate_ids, similarity.view(-1)))
69
+
70
+
71
+ selected_pred_ids = heapq.nlargest(topk, pred_dict, key=pred_dict.get)
72
+ pred_dict = {id: pred_dict[id].item() for id in selected_pred_ids}
73
+
74
+
75
+ return pred_dict
76
+
77
+ if __name__ == '__main__':
78
+ print("Testing Contriever")
Reasoning/text_retrievers/stark_model.py ADDED
@@ -0,0 +1,151 @@
1
+ import os
2
+ import os.path as osp
3
+ from typing import Any, Union, List, Dict
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ import sys
8
+ from pathlib import Path
9
+ # Get the absolute path of the current script
10
+ current_file = Path(__file__).resolve()
11
+ project_root = current_file.parents[2]
12
+ # Add the project root to the system path
13
+ sys.path.append(str(project_root))
14
+
15
+ from stark_qa.tools.api import get_api_embeddings, get_sentence_transformer_embeddings
16
+ from stark_qa.tools.local_encoder import get_llm2vec_embeddings, get_gritlm_embeddings
17
+ from stark_qa.evaluator import Evaluator
18
+
19
+
20
+ class ModelForSTaRKQA(nn.Module):
21
+
22
+ def __init__(self, skb, query_emb_dir='.'):
23
+ """
24
+ Initializes the model with the given knowledge base.
25
+
26
+ Args:
27
+ skb: Knowledge base containing candidate information.
28
+ """
29
+ super(ModelForSTaRKQA, self).__init__()
30
+ self.skb = skb
31
+
32
+ self.candidate_ids = skb.candidate_ids
33
+ self.num_candidates = skb.num_candidates
34
+ self.query_emb_dir = query_emb_dir
35
+
36
+ query_emb_path = osp.join(self.query_emb_dir, 'query_emb_dict.pt')
37
+ if os.path.exists(query_emb_path):
38
+ print(f'Load query embeddings from {query_emb_path}')
39
+ self.query_emb_dict = torch.load(query_emb_path)
40
+ else:
41
+ self.query_emb_dict = {}
42
+ self.evaluator = Evaluator(self.candidate_ids)
43
+
44
+ def forward(self,
45
+ query: Union[str, List[str]],
46
+ candidates: List[int] = None,
47
+ query_id: Union[int, List[int]] = None,
48
+ **kwargs: Any) -> Dict[str, Any]:
49
+ """
50
+ Forward pass to compute predictions for the given query.
51
+
52
+ Args:
53
+ query (Union[str, list]): Query string or a list of query strings.
54
+ candidates (Union[list, None]): A list of candidate ids (optional).
55
+ query_id (Union[int, list, None]): Query index (optional).
56
+
57
+ Returns:
58
+ pred_dict (dict): A dictionary of predicted scores or answer ids.
59
+ """
60
+ raise NotImplementedError
61
+
62
+ def get_query_emb(self,
63
+ query: Union[str, List[str]],
64
+ query_id: Union[int, List[int]],
65
+ emb_model: str = 'text-embedding-ada-002',
66
+ **encode_kwargs) -> torch.Tensor:
67
+ """
68
+ Retrieves or computes the embedding for the given query.
69
+
70
+ Args:
71
+ query (str): Query string.
72
+ query_id (int): Query index.
73
+ emb_model (str): Embedding model to use.
74
+
75
+ Returns:
76
+ query_emb (torch.Tensor): Query embedding.
77
+ """
78
+ if isinstance(query_id, int):
79
+ query_id = [query_id]
80
+ if isinstance(query, str):
81
+ query = [query]
82
+
83
+ encode_kwargs['is_query'] = True
84
+ if query_id is None:
85
+ query_emb = get_embeddings(query, emb_model, **encode_kwargs)
86
+ elif set(query_id).issubset(set(list(self.query_emb_dict.keys()))):
87
+ query_emb = torch.concat([self.query_emb_dict[qid] for qid in query_id], dim=0)
88
+ else:
89
+ query_emb = get_embeddings(query, emb_model, **encode_kwargs)
90
+ for qid, emb in zip(query_id, query_emb):
91
+ self.query_emb_dict[qid] = emb.view(1, -1)
92
+ torch.save(self.query_emb_dict, osp.join(self.query_emb_dir, 'query_emb_dict.pt'))
93
+
94
+ query_emb = query_emb.view(len(query), -1)
95
+ return query_emb
96
+
97
+ def evaluate(self,
98
+ pred_dict: Dict[int, float],
99
+ answer_ids: Union[torch.LongTensor, List[Any]],
100
+ metrics: List[str] = ['mrr', 'hit@3', 'recall@20'],
101
+ **kwargs: Any) -> Dict[str, float]:
102
+ """
103
+ Evaluates the predictions using the specified metrics.
104
+
105
+ Args:
106
+ pred_dict (Dict[int, float]): Predicted answer ids or scores.
107
+ answer_ids (torch.LongTensor): Ground truth answer ids.
108
+ metrics (List[str]): A list of metrics to be evaluated, including 'mrr', 'hit@k', 'recall@k',
109
+ 'precision@k', 'map@k', 'ndcg@k'.
110
+
111
+ Returns:
112
+ Dict[str, float]: A dictionary of evaluation metrics.
113
+ """
114
+ return self.evaluator(pred_dict, answer_ids, metrics)
115
+
116
+ def evaluate_batch(self,
117
+ pred_ids: List[int],
118
+ pred: torch.Tensor,
119
+ answer_ids: Union[torch.LongTensor, List[Any]],
120
+ metrics: List[str] = ['mrr', 'hit@3', 'recall@20'],
121
+ **kwargs: Any) -> Dict[str, float]:
122
+ return self.evaluator.evaluate_batch(pred_ids, pred, answer_ids, metrics)
123
+
124
+
125
+ def get_embeddings(text, model_name, **encode_kwargs):
126
+ """
127
+ Get embeddings for the given text using the specified model.
128
+
129
+ Args:
130
+ model_name (str): Model name.
131
+ text (Union[str, List[str]]): The input text to be embedded.
132
+
133
+ Returns:
134
+ torch.Tensor: Embedding of the input text.
135
+ """
136
+ if isinstance(text, str):
137
+ text = [text]
138
+
139
+ if 'GritLM' in model_name:
140
+ emb = get_gritlm_embeddings(text, model_name, **encode_kwargs)
141
+ elif 'LLM2Vec' in model_name:
142
+ emb = get_llm2vec_embeddings(text, model_name, **encode_kwargs)
143
+ elif 'all-mpnet-base-v2' in model_name or 'dunzhang/stella_en_1.5B_v5' in model_name:
144
+ emb = get_sentence_transformer_embeddings(text, model_name, **encode_kwargs)
145
+ else:
146
+ emb = get_api_embeddings(text, model_name, **encode_kwargs)
147
+ return emb.view(len(text), -1)
148
+
149
+
150
+ if __name__ == '__main__':
151
+ print("Testing ModelForSTaRKQA...")
Reasoning/utils.py ADDED
@@ -0,0 +1,116 @@
1
+ from Reasoning.text_retrievers.bm25 import BM25
2
+ from Reasoning.text_retrievers.ada import Ada
3
+ from Reasoning.text_retrievers.contriever import Contriever
4
+
5
+
6
+ def combine_dicts(dicts_list, pred_dict):
7
+ if len(dicts_list) == 1:
8
+ return dicts_list[0]
9
+ combined_dict = {}
10
+
11
+ for d in dicts_list:
12
+ for key, value in d.items():
13
+ if key in combined_dict:
14
+ # for route dict, the values are lists, keep the longest list
15
+ if len(value) > len(combined_dict[key]):
16
+ combined_dict[key] = value
17
+ else:
18
+ combined_dict[key] = value
19
+
20
+ # if reasoning paths from different routes intersect, keep only the keys present in pred_dict
21
+ combined_dict = {key: combined_dict[key] for key in pred_dict.keys()}
22
+
23
+
24
+ return combined_dict
25
+
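+ # Hedged example (illustrative ids):
+ # combine_dicts([{7: [1, 7]}, {7: [1, 2, 7], 8: [8]}], pred_dict={7: -1.0})
+ # -> {7: [1, 2, 7]} (the longest path per node wins; keys are restricted to pred_dict)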
26
+ def fix_length(paths_dict):
27
+ max_length = 3
28
+ new_paths_dict = {}
29
+
30
+ for key, value in paths_dict.items():
31
+ if len(value) > max_length:
32
+ value = value[-max_length:]
33
+ if len(value) < max_length:
34
+ # padding with -1 at the beginning
35
+ value = [-1] * (max_length - len(value)) + value
36
+ new_paths_dict[key] = value
37
+
38
+ return new_paths_dict
39
+
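+ # Hedged example: fix_length({3: [9], 5: [1, 2, 3, 4]})
+ # -> {3: [-1, -1, 9], 5: [2, 3, 4]} (pad with -1 / truncate so every path has length 3)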
40
+
41
+
42
+ def parse_metapath(metapath):
43
+ """
44
+ input: metapath: "paper -> author -> paper <- paper"
45
+ output: routes: [['paper', 'author', 'paper'], ['paper', 'paper']]
46
+ """
47
+
48
+ def parse(remain_list, direction):
49
+ """
50
+ input: remain_list: ["paper", "->", "author", "->", "paper", "<-", "paper"]
51
+ direction: "->"
52
+ output: route: ["paper", "author", "paper"]
53
+ remain_list: ["paper", "<-", "paper"]
54
+ """
55
+ route = []
56
+ i = 0
57
+ while i < len(remain_list)-1 and remain_list[i+1] == direction:
58
+ route.append(remain_list[i])
59
+ i += 2
60
+ route.append(remain_list[i])
61
+
62
+ if direction == "<-":
63
+ route.reverse()
64
+
65
+ remain_list = None if len(remain_list) == i+1 else remain_list[i:]
66
+
67
+ return route, remain_list
68
+
69
+
70
+ remain_list = metapath.split(' ')
71
+ # print(f"111, {remain_list}")
72
+
73
+ if len(remain_list) == 1: # single node
74
+ return [remain_list]
75
+
76
+ routes = []
77
+ while remain_list is not None:
78
+ if remain_list[1] == "<-":
79
+ route, remain_list = parse(remain_list, "<-")
80
+
81
+ elif remain_list[1] == "->":
82
+ route, remain_list = parse(remain_list, "->")
83
+
84
+ else:
85
+ # raise ValueError(f"Invalid metapath: {metapath}")
86
+ return None
87
+
88
+ routes.append(route)
89
+
90
+ return routes
91
+
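+ # Worked example (matches the docstring): parse_metapath("paper -> author -> paper <- paper")
+ # -> [['paper', 'author', 'paper'], ['paper', 'paper']]; the '<-' leg is reversed so
+ # every returned route runs toward its final (target-type) node.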
92
+
93
+ def get_text_retriever(dataset_name, retriever_name, skb, **kwargs):
94
+ if retriever_name == "bm25":
95
+ return BM25(skb, dataset_name)
96
+ elif retriever_name == "ada":
97
+ return Ada(skb, dataset_name, kwargs.get("device", 'cuda'))
98
+ elif retriever_name == "contriever":
99
+ return Contriever(skb, dataset_name, kwargs.get("device", 'cuda'))
100
+ else:
101
+ raise ValueError(f"Invalid retriever name: {retriever_name}")
102
+
103
+
104
+ def get_scorer(dataset_name, scorer_name, skb, **kwargs):
105
+ if scorer_name == "bm25":
106
+ return BM25(skb, dataset_name)
107
+ elif scorer_name == "ada":
108
+ return Ada(skb, dataset_name, kwargs.get("device",'cuda'))
109
+ elif scorer_name == "contriever":
110
+ return Contriever(skb, dataset_name, kwargs.get("device", 'cuda'))
111
+ else:
112
+ raise ValueError(f"Invalid scorer name: {scorer_name}")
113
+
114
+
115
+ if __name__ == "__main__":
116
+ print(f"Test utils")
Reranking/__pycache__/rerank.cpython-311.pyc ADDED
Binary file (18 kB). View file
 
Reranking/__pycache__/utils.cpython-311.pyc ADDED
Binary file (5.23 kB). View file
 
Reranking/data/checkpoints/amazon/best.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a4ef56bf1095b28283a5bc451556e238b31d36285df39b1a84b9c5c65f1900a2
3
+ size 804607
Reranking/data/checkpoints/mag/best.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b25c04fcb5fa1210fe8ca10bb7c1b3f3aad592ec0e501a1ccc330b92772550f1
3
+ size 804620
Reranking/data/checkpoints/prime/best.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d1e4e24531067e798b975e33689cc87645a2daa9d736df1e8ac0474f72154ce0
3
+ size 804633
Reranking/rerank.py ADDED
@@ -0,0 +1,375 @@
1
+ import sys
2
+ import os
3
+ sys.path.append(os.path.dirname(os.path.abspath(os.getcwd())))
4
+ from stark_qa import load_skb
5
+
6
+
7
+ from torch.utils.data import Dataset, DataLoader
8
+ import torch
9
+ from tqdm import tqdm
10
+ import numpy as np
11
+ import torch.nn as nn
12
+
13
+ from Reranking.utils import move_to_cuda, seed_everything
14
+ from Reranking.rerankers.path import PathReranker
15
+ import torch.nn.functional as F
16
+ import argparse
17
+ import pickle as pkl
18
+
19
+
20
+
21
+
22
+ class TestDataset(Dataset):
23
+ """
24
+ data format: {
25
+ "query": query,
26
+ "pred_dict": {node_id: score},
27
+ 'score_vector_dict': {node_id: [bm25, bm_25, bm25, ada]},
28
+ "text_emb_dict": {node_id: text_emb},
29
+ "ans_ids": [],
30
+ }
31
+
32
+ """
33
+
34
+ def __init__(self, saved_data, args):
35
+
36
+ print(f"Start processing test dataset...")
37
+ self.text2emb_dict = saved_data['text2emb_dict']
38
+ self.data = saved_data['data']
39
+
40
+ self.text_emb_matrix = list(self.text2emb_dict.values())
41
+ self.text_emb_matrix = torch.concat(self.text_emb_matrix, dim=0)
42
+
43
+ # make the mapping between the key of text2emb_dict and the index of text_emb_matrix
44
+ self.text2idx_dict = {key: idx for idx, key in enumerate(self.text2emb_dict.keys())}
45
+
46
+ self.args = args
47
+
48
+
49
+
50
+
51
+ def __len__(self):
52
+ return len(self.data)
53
+
54
+ def __getitem__(self, idx):
55
+
56
+ if self.args.dataset_name == 'amazon':
57
+ # change from the str to index
58
+ self.data[idx]['text_emb_dict'] = {key: self.text2idx_dict[value] for key, value in self.data[idx]['text_emb_dict'].items()}
59
+ else:
60
+ # sort the pred_dict by the score
61
+ pred_dict = self.data[idx]['pred_dict']
62
+ sorted_ids = sorted(pred_dict.keys(), key=lambda x: pred_dict[x], reverse=True)
63
+ # get the top 50 candidates
64
+ sorted_ids = sorted_ids[:50]
65
+ # get the score vector
66
+ self.data[idx]['score_vector_dict'] = {key: self.data[idx]['score_vector_dict'][key] for key in sorted_ids}
67
+ # get the symb_enc_dict
68
+ self.data[idx]['symb_enc_dict'] = {key: self.data[idx]['symb_enc_dict'][key] for key in sorted_ids}
69
+ # change from the str to index
70
+ self.data[idx]['text_emb_dict'] = {key: self.text2idx_dict[value] for key, value in self.data[idx]['text_emb_dict'].items()}
71
+ self.data[idx]['text_emb_dict'] = {key: self.data[idx]['text_emb_dict'][key] for key in sorted_ids}
72
+
73
+
74
+
75
+ return self.data[idx]
76
+
77
+
78
+ def collate_batch(self, batch):
79
+
80
+ # q
81
+ batch_q = [batch[i]['query'] for i in range(len(batch))]
82
+ q_text = batch_q
83
+
84
+ # c
85
+ batch_c = [list(batch[i]['score_vector_dict'].keys()) for i in range(len(batch))] # [batch, 100]
86
+ batch_c = torch.tensor(batch_c)
87
+ c_score_vector = [list(batch[i]['score_vector_dict'].values()) for i in range(len(batch))] # [batch, 100, 4]
88
+ c_score_vector = torch.tensor(c_score_vector)
89
+ c_score_vector = c_score_vector[:, :, :self.args.vector_dim]
90
+
91
+ # c_symb_enc
92
+ c_symb_enc = [list(batch[i]['symb_enc_dict'].values()) for i in range(len(batch))]
93
+ c_symb_enc = torch.tensor(c_symb_enc) # [bs, 100, 3]
94
+
95
+ # c_text_emb
96
+ c_text_emb = [self.text_emb_matrix[list(batch[i]['text_emb_dict'].values())].unsqueeze(0) for i in range(len(batch))]
97
+ c_text_emb = torch.concat(c_text_emb, dim=0) # [bs, 100, 768]
98
+
99
+
100
+ # ans_ids
101
+ ans_ids = [batch[i]['ans_ids'] for i in range(len(batch))] # list of ans_ids
102
+
103
+ # pred_ids
104
+ pred_ids = batch_c.tolist()
105
+
106
+
107
+ # Create a dictionary for the batch
108
+ feed_dict = {
109
+ 'query': q_text,
110
+ 'c_score_vector': c_score_vector,
111
+ 'c_text_emb': c_text_emb,
112
+ 'c_symb_enc': c_symb_enc,
113
+ 'ans_ids': ans_ids,
114
+ 'pred_ids': pred_ids
115
+
116
+ }
117
+
118
+
119
+ return feed_dict
120
+
121
+
122
+ # ***** batch_evaluator *****
123
+ def batch_evaluator(skb, scores_cand, ans_ids, batch):
124
+
125
+ results = {}
126
+
127
+ # **** batch wise evaluation ****
128
+ # evaluate
129
+ candidates_ids = skb.candidate_ids
130
+ id_to_idx = {candidate_id: idx for idx, candidate_id in enumerate(candidates_ids)}
131
+
132
+
133
+ # initialize the pred_matrix
134
+ pred_matrix = torch.zeros((scores_cand.shape[0],len(candidates_ids)))
135
+
136
+
137
+ # get the index of each pred_ids
138
+ # flatten the pred_ids
139
+ flat_pred_ids = torch.tensor(batch['pred_ids']).flatten().tolist()
140
+
141
+
142
+ # get the index of each pred_ids
143
+ pred_idx = [id_to_idx[pred_id] for pred_id in flat_pred_ids]
144
+
145
+
146
+ # reshape the pred_idx
147
+ pred_idx = torch.tensor(pred_idx).reshape(scores_cand.shape[0], -1) # [bs, 100]
148
+
149
+ # move pred_matrix to the device
150
+ pred_matrix = pred_matrix.to(scores_cand.device)
151
+
152
+ # advanced indexing
153
+ pred_matrix[torch.arange(scores_cand.shape[0]).unsqueeze(1), pred_idx] = scores_cand.squeeze(-1) # [bs, num_candidates]
154
+
155
+
156
+ # Create a mapping from candidate IDs to their indices for faster lookup
157
+
158
+
159
+ # Flatten ans_ids to a single list and map them to indices
160
+ flat_ans_idx = [id_to_idx[a_id] for sublist in ans_ids for a_id in sublist]
161
+
162
+ # Create the row indices for ans_matrix corresponding to the answers
163
+ row_indices = torch.repeat_interleave(torch.arange(len(ans_ids)), torch.tensor([len(sublist) for sublist in ans_ids]))
164
+
165
+ # Create the answer matrix
166
+ ans_matrix = torch.zeros((scores_cand.shape[0], len(candidates_ids)), device=scores_cand.device)
167
+ ans_matrix[row_indices, torch.tensor(flat_ans_idx, device=scores_cand.device)] = 1
168
+
169
+
170
+
171
+ # batch computing hit1
172
+ # find the index of the max score
173
+ max_score, max_idx = torch.max(pred_matrix, dim=1)
174
+ # check the label of the max idx
175
+ batch_hit1 = ans_matrix[torch.arange(scores_cand.shape[0]), max_idx]
176
+ hit1_list = batch_hit1.tolist()
177
+
178
+
179
+ # batch computing hit@5
180
+ _, top5_idx = torch.topk(pred_matrix, 5, dim=1)
181
+ batch_hit5 = ans_matrix[torch.arange(scores_cand.shape[0]).unsqueeze(1), top5_idx]
182
+
183
+ # max with each row
184
+ batch_hit5 = torch.max(batch_hit5, dim=1)[0]
185
+ hit5_list = batch_hit5.tolist()
186
+
187
+
188
+
189
+ # batch computing recall@20
190
+ _, top20_idx = torch.topk(pred_matrix, 20, dim=1)
191
+ batch_recall20 = ans_matrix[torch.arange(scores_cand.shape[0]).unsqueeze(1), top20_idx]
192
+ # sum with each row
193
+ batch_recall20 = torch.sum(batch_recall20, dim=1)
194
+ # divide by the sum of the ans_matrix along the row
195
+ batch_recall20 = batch_recall20 / torch.sum(ans_matrix, dim=1)
196
+ recall20_list = batch_recall20.tolist()
197
+
198
+
199
+
200
+ # batch computing mrr
201
+ # find the highest rank of the answer
202
+ _, rank_idx = torch.sort(pred_matrix, dim=1, descending=True)
203
+ # query the answer matrix with the rank_idx
204
+ batch_mrr = ans_matrix[torch.arange(scores_cand.shape[0]).unsqueeze(1), rank_idx]
205
+ # find the first rank of the answer
206
+ batch_mrr = torch.argmax(batch_mrr, dim=1)
207
+ # add 1 to the rank
208
+ batch_mrr += 1
209
+ # divide by the rank
210
+ batch_mrr = 1 / batch_mrr.float()
211
+ mrr_list = batch_mrr.tolist()
212
+
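+ # Note (added): argmax picks the first 1 in the ranked answer row, so a first
+ # correct answer at rank r contributes 1/r; this assumes every query has at
+ # least one answer, since an all-zero row would yield argmax 0 and thus MRR 1.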
213
+
214
+ results['hit@1'] = hit1_list
215
+ results['hit@5'] = hit5_list
216
+ results['recall@20'] = recall20_list
217
+ results['mrr'] = mrr_list
218
+
219
+
220
+ return results
221
+
222
+
223
+
224
+ # ***** evaluate *****
225
+ @torch.no_grad()
226
+ def evaluate(router, test_loader, skb):
227
+
228
+
229
+ router.eval()
230
+
231
+ all_results = {
232
+ "hit@1": [],
233
+ "hit@5": [],
234
+ "recall@20": [],
235
+ "mrr": []
236
+ }
237
+ avg_results = {
238
+ "hit@1": 0,
239
+ "hit@5": 0,
240
+ "recall@20": 0,
241
+ "mrr": 0
242
+ }
243
+
244
+
245
+ # save the scores and ans_ids, and pred_ids
246
+ pred_list = []
247
+ scores_cand_list = []
248
+ ans_ids_list = []
249
+ print(f"Start evaluating...")
250
+ # use tqdm to show the progress
251
+ for idx, batch in enumerate(tqdm(test_loader, desc='Evaluating', position=0)):
252
+ # print(f"idx: {idx}")
253
+ batch = move_to_cuda(batch)
254
+
255
+ # Check if the model is wrapped in DataParallel
256
+ if isinstance(router, nn.DataParallel):
257
+ scores_cand = router.module.eval_batch(batch) # q_emb: [bs, 100], c_emb: [bs*100, 100]
258
+ else:
259
+ scores_cand = router.eval_batch(batch)
260
+
261
+
262
+ # ans_ids
263
+ ans_ids = batch['ans_ids']
264
+
265
+ results = batch_evaluator(skb, scores_cand, ans_ids, batch)
266
+
267
+
268
+ for key in results.keys():
269
+ all_results[key].extend(results[key])
270
+
271
+ # save the scores and ans_ids, and pred_ids
272
+ pred_list.extend(batch['pred_ids'])
273
+ scores_cand_list.extend(scores_cand.cpu().tolist())
274
+ ans_ids_list.extend(ans_ids)
275
+
276
+
277
+
278
+ for key in avg_results.keys():
279
+ avg_results[key] = np.mean(all_results[key])
280
+
281
+ print(f"Results: {avg_results}")
282
+
283
+
284
+
285
+ return avg_results
286
+
287
+
288
+ def parse_args():
289
+
290
+ parser = argparse.ArgumentParser(description="Run PathRouter with dynamic combinations of embeddings.")
291
+
292
+ # dataset_name
293
+ parser.add_argument("--dataset_name", type=str, default="mag", help="Name of the dataset.")
294
+
295
+ # Add arguments for model configurations
296
+ parser.add_argument("--device", type=str, default="cuda", help="Device to run the model (e.g., 'cuda' or 'cpu').")
297
+
298
+
299
+ # add concat_num
300
+ parser.add_argument("--concat_num", type=int, default=0, help="Number of concatenation of embeddings.")
301
+
302
+ # checkpoint save path
303
+ parser.add_argument("--checkpoint_path", type=str, default="./data/checkpoints", help="Path saves the checkpoints.")
304
+
305
+ # similarity vector dim
306
+ parser.add_argument("--vector_dim", type=int, default=4, help="Dimension of the similarity vector.")
307
+
308
+
309
+ # Parse the base arguments
310
+ args = parser.parse_args()
311
+ return args
312
+
313
+
314
+ def get_concat_num(combo):
315
+ """
316
+ Determine the value of concat_num based on the combination of embeddings.
317
+ - score_vec adds +1
318
+ - text_emb adds +1
319
+ - symb_enc adds +3
320
+ """
321
+ concat_num = 0
322
+ if combo.get("score_vec", False): # If score_vec is True
323
+ concat_num += 1
324
+ if combo.get("text_emb", False): # If text_emb is True
325
+ concat_num += 1
326
+ if combo.get("symb_enc", False): # If symb_enc is True
327
+ concat_num += 3
328
+
329
+
330
+ return concat_num
331
+
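+ # Hedged example: get_concat_num({"text_emb": True, "score_vec": True, "symb_enc": True})
+ # -> 5 (1 + 1 + 3), which matches emb_dim*args.concat_num, the input width of
+ # PathReranker.fc1 when all three feature blocks are concatenated.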
332
+
333
+ def run(test_data, skb, dataset_name):
334
+
335
+
336
+
337
+ test_size = 64
338
+ test_dataset = TestDataset(test_data, args=args)
339
+ test_loader = DataLoader(test_dataset, batch_size=test_size, num_workers=32, collate_fn=test_dataset.collate_batch)
340
+
341
+ # load the model
342
+ print(f"Load the model...")
343
+ args.checkpoint_path = args.checkpoint_path + f"/{dataset_name}/best.pth"
344
+ router = PathReranker(socre_vector_input_dim=4, text_emb_input_dim=768, symb_enc_dim=3, args=args)
345
+ checkpoint = torch.load(args.checkpoint_path)
346
+ router.load_state_dict(checkpoint)
347
+ router = router.to(args.device)
348
+
349
+ # evaluate
350
+ test_results = evaluate(router, test_loader, skb)
351
+ print(f"Test evaluation")
352
+ print(test_results)
353
+
354
+ return test_results
355
+
356
+ if __name__ == "__main__":
357
+
358
+ combo = {
359
+ "text_emb": True,
360
+ "score_vec": True,
361
+ "symb_enc": True
362
+ }
363
+ concat_num = get_concat_num(combo)
364
+
365
+ base_args = parse_args()
366
+ args = argparse.Namespace(**vars(base_args), **combo)
367
+ args.concat_num = concat_num
368
+ dataset_name = args.dataset_name
369
+
370
+ test_data_path = f"../{dataset_name}_test.pkl"
371
+ with open(test_data_path, 'rb') as f:
372
+ test_data = pkl.load(f)
373
+ skb = load_skb(dataset_name)
374
+ results = run(test_data, skb, dataset_name)
375
+
Reranking/rerankers/__pycache__/node.cpython-311.pyc ADDED
Binary file (5.38 kB). View file
 
Reranking/rerankers/__pycache__/path.cpython-311.pyc ADDED
Binary file (5.43 kB). View file
 
Reranking/rerankers/node.py ADDED
@@ -0,0 +1,87 @@
1
+ import torch.nn as nn
2
+
3
+
4
+
5
+ # ***** Reranking Model *****
6
+ class Contriever(nn.Module):
7
+ def __init__(self, encoder, emb_dim=768):
8
+ super(Contriever, self).__init__()
9
+ self.encoder = encoder
10
+ self.emb_dim = emb_dim
11
+
12
+ def mean_pooling(self, token_embs, mask):
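+ # (added) token_embs: [bs, seq_len, hidden], mask: [bs, seq_len] -> [bs, hidden];
+ # padded positions are zeroed before the mean so they do not dilute the embedding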
13
+ token_embs = token_embs.masked_fill(~mask[..., None].bool(), 0.0)
14
+ sentence_embeddings = token_embs.sum(dim=1) / (mask.sum(dim=1)[..., None].clamp(min=1e-9))
15
+ return sentence_embeddings
16
+
17
+ def encode_seq(self, input_ids, attention_mask, token_type_ids=None):
18
+ # Combine inputs into a dictionary
19
+ enc = {'input_ids': input_ids, 'attention_mask': attention_mask}
20
+ if token_type_ids is not None:
21
+ enc['token_type_ids'] = token_type_ids
22
+
23
+ outputs = self.encoder(**enc)
24
+ # Mean pooling of last hidden states
25
+ embedded = self.mean_pooling(outputs[0], attention_mask)
26
+ # print(f"777, {embedded.shape}")
27
+ return embedded
28
+
29
+ def get_text_emb(self, input_ids, attention_mask, token_type_ids):
30
+ emb = self.encode_seq(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
31
+
32
+ return emb
33
+
34
+ def eval_batch(self, batch):
35
+ # q_emb: [batch_size/num_gpus, token_dim]
36
+ q_emb = self.get_text_emb(batch['q_enc_input_ids'], batch['q_enc_attention_mask'], batch['q_enc_token_type_ids'])
37
+
38
+ # c_emb: [batch_size * num_candidates/num_gpus, token_dim]
39
+ c_emb = self.get_text_emb(batch['c_enc_input_ids'], batch['c_enc_attention_mask'], batch['c_enc_token_type_ids'])
40
+
41
+
42
+
43
+ return q_emb, c_emb
44
+
45
+ def forward(self, batch):
46
+
47
+ # q_emb: [batch_size/num_gpus, token_dim]
48
+ q_emb = self.get_text_emb(batch['q_enc_input_ids'], batch['q_enc_attention_mask'], batch['q_enc_token_type_ids'])
49
+ # p_emb: [batch_size*max_len/num_gpus, token_dim]
50
+ p_emb = self.get_text_emb(batch['pos_enc_input_ids'], batch['pos_enc_attention_mask'], batch['pos_enc_token_type_ids'])
51
+ # n_emb: [batch_size*max_len*num_sampled_negs/num_gpus, token_dim]
52
+ n_emb = self.get_text_emb(batch['neg_enc_input_ids'], batch['neg_enc_attention_mask'], batch['neg_enc_token_type_ids'])
53
+
54
+
55
+
56
+ return q_emb, p_emb, n_emb
57
+
58
+
59
+ # ***** Reranking Model *****
60
+ class NodeRouter(nn.Module):
61
+ def __init__(self, input_dim=2, output_dim=1, emb_dim=128):
62
+ super(NodeRouter, self).__init__()
63
+ self.fc1 = nn.Linear(input_dim, emb_dim)
64
+ self.fc2 = nn.Linear(emb_dim, output_dim)
65
+ self.relu = nn.ReLU()
66
+
67
+ def eval_batch(self, batch):
68
+ scores_cand = self.fc1(batch['c_scores'])
69
+ scores_cand = self.relu(scores_cand)
70
+ scores_cand = self.fc2(scores_cand)
71
+ scores_cand = self.relu(scores_cand)
72
+
73
+ return scores_cand
74
+
75
+ def forward(self, batch):
76
+ scores_pos = self.fc1(batch['p_scores'])
77
+ scores_neg = self.fc1(batch['n_scores'])
78
+ scores_pos = self.relu(scores_pos)
79
+ scores_neg = self.relu(scores_neg)
80
+
81
+ scores_pos = self.fc2(scores_pos)
82
+ scores_neg = self.fc2(scores_neg)
83
+ scores_pos = self.relu(scores_pos)
84
+ scores_neg = self.relu(scores_neg)
85
+
86
+ return scores_pos, scores_neg
87
+
Reranking/rerankers/path.py ADDED
@@ -0,0 +1,107 @@
1
+ import torch.nn as nn
2
+ import torch
3
+
4
+ # ***** Reranking Model *****
5
+ # two linear layers
6
+ class PathReranker(nn.Module):
7
+ def __init__(self, socre_vector_input_dim=4, text_emb_input_dim=768, symb_enc_dim=3, output_dim=1, emb_dim=256, args=None):
8
+ super(PathReranker, self).__init__()
9
+ self.score_vec_enc = nn.Linear(socre_vector_input_dim, emb_dim)
10
+ self.text_emb_enc = nn.Linear(text_emb_input_dim, emb_dim)
11
+ self.symb_enc = nn.Embedding(symb_enc_dim, emb_dim)
12
+ self.fc1 = nn.Linear(emb_dim*args.concat_num, output_dim)
13
+ self.fc2 = nn.Linear(output_dim, 1)
14
+ self.relu = nn.ReLU()
15
+ self.args = args
16
+
17
+
18
+
19
+ def eval_batch(self, batch):
20
+
21
+ embeddings = []
22
+
23
+ if self.args.text_emb:
24
+ # Encode the text embedding and apply ReLU
25
+ text_emb_c = self.relu(self.text_emb_enc(batch['c_text_emb'])) # [bs, 100, emb_dim]
26
+ embeddings.append(text_emb_c)
27
+
28
+
29
+ if self.args.score_vec:
30
+ # Encode the score vector and apply ReLU
31
+ score_vector_c = self.relu(self.score_vec_enc(batch['c_score_vector'])) # [bs, 100, emb_dim]
32
+
33
+ embeddings.append(score_vector_c)
34
+
35
+
36
+ if self.args.symb_enc:
37
+ # encode the symbolic embedding and apply ReLU
38
+ symb_enc_c = self.relu(self.symb_enc(batch['c_symb_enc']))
39
+ # reshape the symbolic embedding
40
+ symb_enc_c = torch.reshape(symb_enc_c, (symb_enc_c.shape[0], symb_enc_c.shape[1], -1))
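+ # (added) flattens the per-hop embeddings: [bs, 100, 3, emb_dim] -> [bs, 100, 3*emb_dim],
+ # which is why symb_enc contributes +3 in get_concat_num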
41
+ embeddings.append(symb_enc_c)
42
+
43
+
44
+ if len(embeddings) > 1:
45
+ emb_c = torch.cat(embeddings, dim=-1)
46
+ else:
47
+ emb_c = embeddings[0]
48
+
49
+
50
+ # Feed the concatenated embeddings to the final layer
51
+ emb_c = self.fc1(emb_c) # [bs, 100, emb_dim]
52
+ scores_c = self.fc2(emb_c) # [bs, 100, 1]
53
+
54
+ return scores_c
55
+
56
+
57
+ def forward(self, batch):
58
+
59
+ embeddings_pos = []
60
+ embeddings_neg = []
61
+
62
+ if self.args.text_emb:
63
+ # Encode the text embedding and apply ReLU
64
+ text_emb_pos = self.relu(self.text_emb_enc(batch['p_text_emb']))
65
+ text_emb_neg = self.relu(self.text_emb_enc(batch['n_text_emb']))
66
+ embeddings_pos.append(text_emb_pos)
67
+ embeddings_neg.append(text_emb_neg)
68
+
69
+ if self.args.score_vec:
70
+ # Encode the score vector and apply ReLU
71
+ score_vector_pos = self.relu(self.score_vec_enc(batch['p_score_vector']))
72
+ score_vector_neg = self.relu(self.score_vec_enc(batch['n_score_vector']))
73
+ embeddings_pos.append(score_vector_pos)
74
+ embeddings_neg.append(score_vector_neg)
75
+
76
+
77
+ if self.args.symb_enc:
78
+ # encode the symbolic embedding and apply ReLU
79
+ symb_enc_pos = self.relu(self.symb_enc(batch['p_symb_enc']))
80
+ # reshape the symbolic embedding
81
+ symb_enc_pos = torch.reshape(symb_enc_pos, (symb_enc_pos.shape[0], -1))
82
+
83
+ symb_enc_neg = self.relu(self.symb_enc(batch['n_symb_enc']))
84
+ # reshape the symbolic embedding
85
+ symb_enc_neg = torch.reshape(symb_enc_neg, (symb_enc_neg.shape[0], symb_enc_neg.shape[1], -1)) # [bs, neg_sp, path_len * emb_dim]
86
+
87
+ embeddings_pos.append(symb_enc_pos)
88
+ embeddings_neg.append(symb_enc_neg)
89
+
90
+
91
+ if len(embeddings_pos) > 1:
92
+ pos = torch.cat(embeddings_pos, dim=-1)
93
+ neg = torch.cat(embeddings_neg, dim=-1)
94
+ else:
95
+ pos = embeddings_pos[0]
96
+ neg = embeddings_neg[0]
97
+
98
+
99
+
100
+ # Feed the concatenated embeddings to the final layer
101
+ pos = self.fc1(pos)
102
+ neg = self.fc1(neg)
103
+ scores_pos = self.fc2(pos)
104
+ scores_neg = self.fc2(neg)
105
+
106
+ return scores_pos, scores_neg
107
+
Reranking/train_eval_path_amazon.py ADDED
@@ -0,0 +1,694 @@
1
+ import os
+ import sys
2
+ sys.path.append(os.path.dirname(os.path.abspath(os.getcwd())))
3
+
4
+ import pickle as pkl
5
+ from torch.utils.data import Dataset, DataLoader
6
+ import torch
7
+ from tqdm import tqdm
8
+ import wandb
9
+ import numpy as np
10
+ import time
11
+ from torch_scatter import segment_csr, scatter_mean
12
+ from itertools import product
13
+ from transformers import AutoModel, AutoTokenizer, AutoConfig
14
+ import torch.nn as nn
15
+ from torch.nn import CrossEntropyLoss
16
+ import random
17
+ from collections import defaultdict
18
+
19
+
20
+ from Reranking.utils import move_to_cuda, seed_everything
21
+ from Reranking.rerankers.path import PathReranker
22
+ from utils import ModelForSTaRKQA
23
+ from stark_qa import load_qa, load_skb
24
+ import torch.nn.functional as F
25
+ import argparse
26
+ import json
27
+
28
+
29
+
30
+ seed_everything(42)
31
+
32
+ # ***** Dataset *****
33
+ class TrainDataset(Dataset):
34
+ """
35
+ Custom Dataset for the training data.
36
+ Each instance contains multiple positive and negative candidates.
37
+ """
38
+ def __init__(self, saved_data, max_neg_candidates=100):
39
+ """
40
+ 10s for 1000 data
41
+ """
42
+ print(f"start processing training dataset...")
43
+ s_time = time.time()
44
+ self.max_neg_candidates = max_neg_candidates
45
+ self.sorted_query2neg = defaultdict(list)
46
+
47
+
48
+ self.text2emb_dict = saved_data['text2emb_dict']
49
+ self.data = saved_data['data']
50
+
51
+
52
+ # separate neg and pos, and prepare (query, pos) pairs
53
+ new_data = []
54
+
55
+ for i in range(len(self.data)):
56
+ neg_ids = []
57
+ pos_ids = []
58
+ item = self.data[i]
59
+
60
+
61
+ candidates_dict = item['pred_dict']
62
+ ans_ids = item['ans_ids']
63
+ # pos_ids = ans_ids
64
+ for ans_id in ans_ids:
65
+ if ans_id in candidates_dict.keys():
66
+ pos_ids.append(ans_id)
67
+ neg_ids = list(set(candidates_dict.keys()) - set(pos_ids))
68
+
69
+ # load scores vector
70
+ score_vector_dict = item['score_vector_dict']
71
+
72
+ # load the text path, str format
73
+ text_emb_dict = item['text_emb_dict']
74
+
75
+ # load the symb_enc_dict
76
+ symb_enc_dict = item['symb_enc_dict']
77
+
78
+
79
+ self.data[i]['pos_ids'] = pos_ids
80
+ self.data[i]['neg_ids'] = neg_ids
81
+
82
+ query = item['query']
83
+ for pos_id in pos_ids:
84
+ new_data.append((query, score_vector_dict[pos_id], self.text2emb_dict[text_emb_dict[pos_id]], symb_enc_dict[pos_id]))
85
+
86
+
87
+ # print(f"new_data: {new_data}")
88
+
89
+ neg_dict = {neg_id: candidates_dict[neg_id] for neg_id in neg_ids}
90
+ sorted_neg_ids = sorted(neg_dict.keys(), key=lambda x: neg_dict[x], reverse=True) # return list
91
+
92
+
93
+ self.sorted_query2neg[query] = [(score_vector_dict[neg_id], self.text2emb_dict[text_emb_dict[neg_id]], symb_enc_dict[neg_id]) for neg_id in sorted_neg_ids]
94
+
95
+
96
+ self.data = new_data
97
+ print(f"Complete data preparation")
98
+ print(f"Time: {time.time() - s_time}")
99
+
100
+
101
+
102
+
103
+ def __len__(self):
104
+ return len(self.data)
105
+
106
+ def __getitem__(self, idx):
107
+
108
+ return self.data[idx]
109
+
110
+ def collate_batch(self, pairs):
111
+ s_time = time.time()
112
+
113
+ # q
114
+ batch_q = [pair[0] for pair in pairs] # q is text
115
+ q_text = batch_q
116
+ # print(f"q111, {q_text}")
117
+
118
+
119
+ # pos
120
+ # get the score vector
121
+ batch_p_score_vector = [pair[1] for pair in pairs] # p is score vector
122
+ batch_p_score_vector = torch.tensor(batch_p_score_vector) # [bs, 4]
123
+ batch_p_score_vector = batch_p_score_vector[:, :args.vector_dim]
124
+ # get the text emb
125
+ batch_p_text_emb = [pair[2] for pair in pairs] # p is text emb
126
+ batch_p_text_emb = torch.concat(batch_p_text_emb, dim=0) # [bs, 768]
127
+ # get the symb_enc
128
+ batch_p_symb_enc = [pair[3] for pair in pairs] # p is symb_enc
129
+ batch_p_symb_enc = torch.tensor(batch_p_symb_enc) # [bs, 3]
130
+
131
+
132
+ # Negative samples
133
+ batch_n = [random.choices(self.sorted_query2neg[query], k=self.max_neg_candidates) for query in batch_q] # allow duplicates
134
+
135
+
136
+ # get the score vector
137
+ batch_n_score_vector = [pair[0] for sublist in batch_n for pair in sublist]
138
+ batch_n_score_vector = torch.tensor(batch_n_score_vector) # [bs*100, 4]
139
+ # reshape to [bs, 100, 4]
140
+ batch_n_score_vector = batch_n_score_vector.reshape(len(batch_q), self.max_neg_candidates, -1) # [bs, 100, 4]
141
+ batch_n_score_vector = batch_n_score_vector[:, :, :args.vector_dim]
142
+
143
+ # get the text emb
144
+ batch_n_text_emb = [pair[1] for sublist in batch_n for pair in sublist]
145
+ batch_n_text_emb = torch.concat(batch_n_text_emb, dim=0) # [bs*100, 768]
146
+ # reshape to [bs, 100, 768]
147
+ batch_n_text_emb = batch_n_text_emb.reshape(len(batch_q), self.max_neg_candidates, -1) # [bs, 100, 768]
148
+
149
+ # get the symb_enc
150
+ batch_n_symb_enc = [pair[2] for sublist in batch_n for pair in sublist]
151
+ batch_n_symb_enc = torch.tensor(batch_n_symb_enc) # [bs*100, 3]
152
+ # reshape to [bs, 100, 3]
153
+ batch_n_symb_enc = batch_n_symb_enc.reshape(len(batch_q), self.max_neg_candidates, -1) # [bs, 100, 3]
154
+
155
+
156
+
157
+
158
+
159
+ # Create a dictionary for the batch
160
+ feed_dict = {
161
+ 'query': q_text,
162
+ 'p_score_vector': batch_p_score_vector,
163
+ 'p_text_emb': batch_p_text_emb,
164
+ 'p_symb_enc': batch_p_symb_enc,
165
+ 'n_score_vector': batch_n_score_vector,
166
+ 'n_text_emb': batch_n_text_emb,
167
+ 'n_symb_enc': batch_n_symb_enc,
168
+
169
+ }
170
+
171
+
172
+ return feed_dict
173
+
174
+
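`collate_batch` above pairs every (query, positive) row with `max_neg_candidates` negatives resampled per batch. The sampling step it relies on is `random.choices`, which draws with replacement, so it works even when a query has fewer than 100 negatives. A toy illustration (string stand-ins for the real (score_vector, text_emb, symb_enc) triples):

    import random

    sorted_negs = ["neg_a", "neg_b", "neg_c"]      # pre-sorted by retriever score
    batch_negs = random.choices(sorted_negs, k=5)  # with replacement; k > len is fine
    print(batch_negs)  # e.g. ['neg_b', 'neg_b', 'neg_a', 'neg_c', 'neg_b']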
175
+ class TestDataset(Dataset):
176
+ """
177
+ data format: {
178
+ "query": query,
179
+ "pred_dict": {node_id: score},
180
+ 'score_vector_dict': {node_id: [bm25, bm25, bm25, ada]},
181
+ "text_emb_dict": {node_id: text_emb},
182
+ "ans_ids": [],
183
+ }
184
+
185
+ """
186
+
187
+ def __init__(self, saved_data):
188
+
189
+ print(f"Start processing test dataset...")
190
+ self.text2emb_dict = saved_data['text2emb_dict']
191
+ self.data = saved_data['data']
192
+
193
+ self.text_emb_matrix = list(self.text2emb_dict.values())
194
+ self.text_emb_matrix = torch.concat(self.text_emb_matrix, dim=0)
195
+
196
+ # make the mapping between the key of text2emb_dict and the index of text_emb_matrix
197
+ self.text2idx_dict = {key: idx for idx, key in enumerate(self.text2emb_dict.keys())}
198
+
199
+ print(f"Complete data preparation: {len(self.data)}")
200
+
201
+
202
+
203
+
204
+ def __len__(self):
205
+ return len(self.data)
206
+
207
+ def __getitem__(self, idx):
208
+ # ***** amazon *****
209
+ # change from the str to index
210
+ self.data[idx]['text_emb_dict'] = {key: self.text2idx_dict[value] for key, value in self.data[idx]['text_emb_dict'].items()}
211
+
212
+ return self.data[idx]
213
+
214
+ def collate_batch(self, batch):
215
+
216
+ # q
217
+ batch_q = [batch[i]['query'] for i in range(len(batch))]
218
+ q_text = batch_q
219
+
220
+ # c
221
+ batch_c = [list(batch[i]['score_vector_dict'].keys()) for i in range(len(batch))] # [batch, 100]
222
+ batch_c = torch.tensor(batch_c)
223
+ # print(f"111, {batch_c.shape}")
224
+ c_score_vector = [list(batch[i]['score_vector_dict'].values()) for i in range(len(batch))] # [batch, 100, 4]
225
+ c_score_vector = torch.tensor(c_score_vector)[:, :, :args.vector_dim] # [batch, 100, 4]
226
+
227
+
228
+ # print(f"222, {c_vector.shape}")
229
+ # c_text_emb
230
+ # c_text_emb = [torch.concat(list(batch[i]['text_emb_dict'].values()), dim=0).unsqueeze(0) for i in range(len(batch))]
231
+ c_text_emb = [self.text_emb_matrix[list(batch[i]['text_emb_dict'].values())].unsqueeze(0) for i in range(len(batch))]
232
+ c_text_emb = torch.concat(c_text_emb, dim=0) # [bs, 100, 768]
233
+
234
+ # c_symb_enc
235
+ c_symb_enc = [list(batch[i]['symb_enc_dict'].values()) for i in range(len(batch))]
236
+ c_symb_enc = torch.tensor(c_symb_enc) # [bs, 100, 3]
237
+
238
+
239
+ # ans_ids
240
+ ans_ids = [batch[i]['ans_ids'] for i in range(len(batch))] # list of ans_ids
241
+
242
+ # pred_ids
243
+ pred_ids = batch_c.tolist()
244
+
245
+
246
+ # Create a dictionary for the batch
247
+ feed_dict = {
248
+ 'query': q_text,
249
+ 'c_score_vector': c_score_vector,
250
+ 'c_text_emb': c_text_emb,
251
+ 'c_symb_enc': c_symb_enc,
252
+ 'ans_ids': ans_ids,
253
+ 'pred_ids': pred_ids
254
+
255
+ }
256
+
257
+
258
+
259
+ return feed_dict
260
+
261
+
262
+ # ******* loss function ********
263
+ def loss_fn(scores_pos, scores_neg):
264
+
265
+ loss_fct = CrossEntropyLoss(ignore_index=-1)
266
+
267
+ # Combine scores
268
+ scores = torch.cat([scores_pos, scores_neg.squeeze(-1)], dim=1) # [bs, 1 + max_neg_candidates]
269
+ # print(f"scores: {scores.shape}")
270
+
271
+ # Create target
272
+ target = torch.zeros(scores.size(0), dtype=torch.long).to(scores.device)
273
+
274
+ # Compute loss
275
+ loss = loss_fct(scores, target)
276
+
277
+ return loss
278
+
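`loss_fn` is a listwise softmax cross-entropy: each row holds the positive score followed by that query's negatives, and the target class is always index 0. A runnable toy check with made-up scores:

    import torch
    from torch.nn import CrossEntropyLoss

    scores_pos = torch.tensor([[2.0], [0.5]])                    # [bs, 1]
    scores_neg = torch.tensor([[[1.0], [0.0]], [[1.5], [0.2]]])  # [bs, n_neg, 1]

    scores = torch.cat([scores_pos, scores_neg.squeeze(-1)], dim=1)  # [bs, 1 + n_neg]
    target = torch.zeros(scores.size(0), dtype=torch.long)           # positive = class 0
    loss = CrossEntropyLoss(ignore_index=-1)(scores, target)
    print(loss.item())  # shrinks as the positive outscores its negatives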
279
+ # ***** pairwise loss *****
280
+ def pairwise_loss(scores_pos, scores_neg, margin=0.5):
281
+ # scores_pos: [bs, 1]
282
+ # scores_neg: [bs, 100, 1]
283
+
284
+ # Compute loss
285
+ differences = scores_pos.unsqueeze(1) - scores_neg - margin # [bs, 100, 1]
286
+ differences = differences.view(-1) # [bs*100]
287
+ loss = F.relu(-differences).mean() # Standard pairwise loss
288
+
289
+ return loss
290
+
291
+
292
+ def batch_evaluator(skb, scores_cand, ans_ids, batch):
293
+
294
+ results = {}
295
+
296
+ # **** batch wise evaluation ****
297
+ # evaluate
298
+ candidates_ids = skb.candidate_ids
299
+ id_to_idx = {candidate_id: idx for idx, candidate_id in enumerate(candidates_ids)}
300
+
301
+
302
+ # initialize the pred_matrix
303
+ pred_matrix = torch.zeros((scores_cand.shape[0],len(candidates_ids)))
304
+
305
+
306
+ # get the index of each pred_ids
307
+ # flatten the pred_ids
308
+ flat_pred_ids = torch.tensor(batch['pred_ids']).flatten().tolist()
309
+
310
+
311
+ # get the index of each pred_ids
312
+ pred_idx = [id_to_idx[pred_id] for pred_id in flat_pred_ids]
313
+
314
+
315
+ # reshape the pred_idx
316
+ pred_idx = torch.tensor(pred_idx).reshape(scores_cand.shape[0], -1) # [bs, 100]
317
+
318
+ # move pred_matrix to the device
319
+ pred_matrix = pred_matrix.to(scores_cand.device)
320
+
321
+ # advanced indexing
322
+ pred_matrix[torch.arange(scores_cand.shape[0]).unsqueeze(1), pred_idx] = scores_cand.squeeze(-1) # [bs, num_candidates]
323
+
324
+
325
+
326
+ # Create a mapping from candidate IDs to their indices for faster lookup
327
+
328
+
329
+ # Flatten ans_ids to a single list and map them to indices
330
+ flat_ans_idx = [id_to_idx[a_id] for sublist in ans_ids for a_id in sublist]
331
+
332
+ # Create the row indices for ans_matrix corresponding to the answers
333
+ row_indices = torch.repeat_interleave(torch.arange(len(ans_ids)), torch.tensor([len(sublist) for sublist in ans_ids]))
334
+
335
+ # Create the answer matrix
336
+ ans_matrix = torch.zeros((scores_cand.shape[0], len(candidates_ids)), device=scores_cand.device)
337
+ ans_matrix[row_indices, torch.tensor(flat_ans_idx, device=scores_cand.device)] = 1
338
+
339
+
340
+
341
+
342
+
343
+ # batch computing hit1
344
+ # find the index of the max score
345
+ max_score, max_idx = torch.max(pred_matrix, dim=1)
346
+ # check the label of the max idx
347
+ batch_hit1 = ans_matrix[torch.arange(scores_cand.shape[0]), max_idx]
348
+ hit1_list = batch_hit1.tolist()
349
+
350
+
351
+
352
+
353
+ # batch computing hit@5
354
+ _, top5_idx = torch.topk(pred_matrix, 5, dim=1)
355
+ batch_hit5 = ans_matrix[torch.arange(scores_cand.shape[0]).unsqueeze(1), top5_idx]
356
+
357
+ # max with each row
358
+ batch_hit5 = torch.max(batch_hit5, dim=1)[0]
359
+ hit5_list = batch_hit5.tolist()
360
+
361
+
362
+
363
+ # batch computing recall@20
364
+ _, top20_idx = torch.topk(pred_matrix, 20, dim=1)
365
+ batch_recall20 = ans_matrix[torch.arange(scores_cand.shape[0]).unsqueeze(1), top20_idx]
366
+ # sum with each row
367
+ batch_recall20 = torch.sum(batch_recall20, dim=1)
368
+ # divide by the sum of the ans_matrix along the row
369
+ batch_recall20 = batch_recall20 / torch.sum(ans_matrix, dim=1)
370
+ recall20_list = batch_recall20.tolist()
371
+
372
+
373
+
374
+
375
+ # batch computing mrr
376
+ # find the highest rank of the answer
377
+ _, rank_idx = torch.sort(pred_matrix, dim=1, descending=True)
378
+ # query the answer matrix with the rank_idx
379
+ batch_mrr = ans_matrix[torch.arange(scores_cand.shape[0]).unsqueeze(1), rank_idx]
380
+ # find the first rank of the answer
381
+ batch_mrr = torch.argmax(batch_mrr, dim=1)
382
+ # add 1 to the rank
383
+ batch_mrr += 1
384
+ # divide by the rank
385
+ batch_mrr = 1 / batch_mrr.float()
386
+ mrr_list = batch_mrr.tolist()
387
+
388
+
389
+
390
+
391
+ results['hit@1'] = hit1_list
392
+ results['hit@5'] = hit5_list
393
+ results['recall@20'] = recall20_list
394
+ results['mrr'] = mrr_list
395
+
396
+
397
+
398
+
399
+
400
+
401
+ return results
402
+
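`batch_evaluator` avoids per-query Python loops by scattering the reranker scores into a dense `[bs, num_candidates]` matrix and reading hit@k/recall/MRR off tensor ops. The scatter-then-hit@1 core, on a toy example with two queries, five global candidates, and invented scores:

    import torch

    scores_cand = torch.tensor([[0.9, 0.1], [0.2, 0.8]])  # [bs, k] reranked scores
    pred_idx = torch.tensor([[3, 1], [0, 4]])             # their global candidate indices

    pred_matrix = torch.zeros(2, 5)
    pred_matrix[torch.arange(2).unsqueeze(1), pred_idx] = scores_cand

    ans_matrix = torch.zeros(2, 5)
    ans_matrix[0, 3] = 1  # query 0's answer is candidate 3
    ans_matrix[1, 4] = 1  # query 1's answer is candidate 4

    _, max_idx = torch.max(pred_matrix, dim=1)
    hit1 = ans_matrix[torch.arange(2), max_idx]  # tensor([1., 1.])

The same pattern extends to hit@5 (`torch.topk`) and MRR (`torch.sort` descending, then `argmax` over the 0/1 row to find the first relevant rank).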
403
+
404
+
405
+ # ***** evaluate *****
406
+ @torch.no_grad()
407
+ def evaluate(reranker, test_loader):
408
+
409
+
410
+ reranker.eval()
411
+
412
+ all_results = {
413
+ "hit@1": [],
414
+ "hit@5": [],
415
+ "recall@20": [],
416
+ "mrr": []
417
+ }
418
+ avg_results = {
419
+ "hit@1": 0,
420
+ "hit@5": 0,
421
+ "recall@20": 0,
422
+ "mrr": 0
423
+ }
424
+
425
+
426
+ # save the scores and ans_ids, and pred_ids
427
+ pred_list = []
428
+ scores_cand_list = []
429
+ ans_ids_list = []
430
+ # use tqdm to show the progress
431
+ for idx, batch in enumerate(tqdm(test_loader, desc='Evaluating', position=0)):
432
+ batch = move_to_cuda(batch)
433
+
434
+ # Check if the model is wrapped in DataParallel
435
+ if isinstance(reranker, nn.DataParallel):
436
+ scores_cand = reranker.module.eval_batch(batch) # scores_cand: [bs, n_candidates, 1]
437
+ else:
438
+ scores_cand = reranker.eval_batch(batch)
439
+
440
+
441
+ # ans_ids
442
+ ans_ids = batch['ans_ids']
443
+
444
+ results = batch_evaluator(skb, scores_cand, ans_ids, batch)
445
+
446
+
447
+ for key in results.keys():
448
+ all_results[key].extend(results[key])
449
+
450
+ # save the scores and ans_ids, and pred_ids
451
+ pred_list.extend(batch['pred_ids'])
452
+ scores_cand_list.extend(scores_cand.cpu().tolist())
453
+ ans_ids_list.extend(ans_ids)
454
+
455
+
456
+
457
+ for key in avg_results.keys():
458
+ avg_results[key] = np.mean(all_results[key])
459
+
460
+ print(f"Results: {avg_results}")
461
+
462
+
463
+
464
+ return avg_results
465
+
466
+
467
+ # ***** train *****
468
+ def main(train_data, val_data, test_data, skb, dataset_name, args):
469
+
470
+
471
+ epochs = args.epochs
472
+ device = args.device
473
+
474
+ train_size = args.train_batch_size
475
+ test_size = 64
476
+
477
+ train_dataset = TrainDataset(train_data)
478
+ train_loader = DataLoader(train_dataset, batch_size=train_size, num_workers=32, collate_fn=train_dataset.collate_batch, drop_last=True)
479
+
480
+ test_dataset = TestDataset(test_data)
481
+ test_loader = DataLoader(test_dataset, batch_size=test_size, num_workers=32, collate_fn=test_dataset.collate_batch)
482
+
483
+ val_dataset = TestDataset(val_data)
484
+ val_loader = DataLoader(val_dataset, batch_size=test_size, num_workers=32, collate_fn=val_dataset.collate_batch)
485
+
486
+
487
+ # ***** Model *****
488
+ reranker = PathReranker(socre_vector_input_dim=args.vector_dim, text_emb_input_dim=768, symb_enc_dim=3, args=args)
489
+ save_dir = f"./data/checkpoints/{dataset_name}/path"
490
+ os.makedirs(save_dir, exist_ok=True)
491
+
492
+ reranker.to(device)
493
+ # # parallel processing
494
+ reranker = nn.DataParallel(reranker)
495
+
496
+
497
+ optimizer = torch.optim.Adam(reranker.parameters(), lr=args.lr)
498
+ best_val_hit1 = float('-inf')
499
+
500
+
501
+ val_results = evaluate(reranker, val_loader)
502
+ print(f"Val evaluation")
503
+ print(val_results)
504
+
505
+
506
+ test_results = evaluate(reranker, test_loader)
507
+ print(f"Test evaluation")
508
+ print(test_results)
509
+
510
+ # log both val and test results
511
+ wandb.log({'val_mrr': val_results['mrr'], 'val_hit1': val_results['hit@1'], 'val_hit5': val_results['hit@5'], 'val_recall@20': val_results['recall@20'],
512
+ 'test_mrr': test_results['mrr'], 'test_hit1': test_results['hit@1'], 'test_hit5': test_results['hit@5'], 'test_recall@20': test_results['recall@20']})
513
+
514
+ best_test_results = {}
515
+ for epoch in tqdm(range(epochs), desc='Training Epochs', position=0):
516
+ total_loss = 0.0
517
+ reranker.train()
518
+ count = 0
519
+ total_instances = 0
520
+
521
+ for batch in tqdm(train_loader):
522
+ # print(batch)
523
+ batch = move_to_cuda(batch)
524
+ # print(batch)
525
+
526
+ scores_pos, scores_neg = reranker(batch)
527
+
528
+ # batch_loss = pairwise_loss(scores_pos, scores_neg)
529
+ batch_loss = loss_fn(scores_pos, scores_neg)
530
+
531
+ # clear optimizer
532
+ optimizer.zero_grad()
533
+ batch_loss.backward()
534
+ optimizer.step()
535
+
536
+ # total_loss += batch_loss.item()
537
+ count += 1
538
+ # compute the average loss
539
+ total_instances += scores_pos.shape[0]
540
+ total_loss += batch_loss.item()
541
+
542
+
543
+ train_loss = total_loss / total_instances
544
+
545
+ print(f"Epoch {epoch+1}/{epochs}, Average Train Loss: {train_loss}")
546
+
547
+
548
+
549
+ val_results = evaluate(reranker, val_loader)
550
+ print(f"Val evaluation")
551
+ print(val_results)
552
+
553
+
554
+ test_results = evaluate(reranker, test_loader)
555
+ print(f"Test evaluation")
556
+ print(test_results)
557
+
558
+ # log both val and test results
559
+ wandb.log({'val_mrr': val_results['mrr'], 'val_hit1': val_results['hit@1'], 'val_hit5': val_results['hit@5'], 'val_recall@20': val_results['recall@20'],
560
+ 'test_mrr': test_results['mrr'], 'test_hit1': test_results['hit@1'], 'test_hit5': test_results['hit@5'], 'test_recall@20': test_results['recall@20'],
561
+ 'train_loss': train_loss})
562
+
563
+
564
+ # save the best model when val hit1 is the highest
565
+ hit1 = val_results['hit@1']
566
+ if best_val_hit1 < hit1:
567
+ best_val_hit1 = hit1
568
+
569
+ save_path = f"{save_dir}/best_{best_val_hit1}.pth"
570
+
571
+ if isinstance(reranker, nn.DataParallel):
572
+ torch.save(reranker.module.state_dict(), save_path)
573
+ else:
574
+ torch.save(reranker.state_dict(), save_path)
575
+ print(f"Checkpoint saved at epoch {epoch+1} with test hits@1 {hit1}")
576
+
577
+ args.checkpoint_path = save_path
578
+ best_test_results = test_results
579
+
580
+
581
+
582
+ # save the last-epoch checkpoint
583
+ save_path = f"{save_dir}/last_{hit1}.pth"
584
+ if isinstance(reranker, nn.DataParallel):
585
+ torch.save(reranker.module.state_dict(), save_path)
586
+ else:
587
+ torch.save(reranker.state_dict(), save_path)
588
+ print(f"Final checkpoint saved at {save_path}")
589
+
590
+
591
+
592
+ # ***** save the results *****
593
+ results = []
594
+ results.append(
595
+ {
596
+ "config": vars(args),
597
+ "test_results": best_test_results
598
+ }
599
+ )
600
+ # save the results to json
601
+ timestamp = time.strftime("%Y%m%d-%H%M%S")
602
+ output_dir = f"./data/outputs/{dataset_name}"
603
+ os.makedirs(output_dir, exist_ok=True)
604
+ with open(f"{output_dir}/results_{timestamp}.json", "w") as f:
605
+ json.dump(results, f, indent=4)
606
+
607
+ print(best_test_results)
608
+
609
+
610
+ def get_concat_num(combo):
611
+ """
612
+ Determine the value of concat_num based on the combination of embeddings.
613
+ - score_vec adds +1
614
+ - text_emb adds +1
615
+ - symb_enc adds +3
616
+ """
617
+ concat_num = 0
618
+ if combo.get("score_vec", False): # If score_vec is True
619
+ concat_num += 1
620
+ if combo.get("text_emb", False): # If text_emb is True
621
+ concat_num += 1
622
+ if combo.get("symb_enc", False): # If symb_enc is True
623
+ concat_num += 3
624
+
625
+
626
+ return concat_num
627
+
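For the all-features combination used in `__main__` below, the count works out to 1 + 1 + 3 = 5:

    get_concat_num({"score_vec": True, "text_emb": True, "symb_enc": True})  # -> 5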
628
+
629
+
630
+ def parse_args():
631
+
632
+ parser = argparse.ArgumentParser(description="Run PathReranker with dynamic combinations of embeddings.")
633
+
634
+ # Add arguments for model configurations
635
+ parser.add_argument("--train_batch_size", type=int, default=64, help="Batch size for training or evaluation.")
636
+ parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate for optimizer.")
637
+ parser.add_argument("--epochs", type=int, default=100, help="Number of epochs to train the model.")
638
+ parser.add_argument("--device", type=str, default="cuda", help="Device to run the model (e.g., 'cuda' or 'cpu').")
639
+
640
+ # Add arguments for the dataset
641
+ parser.add_argument("--dataset_name", type=str, default="amazon", help="Name of the dataset to use.")
642
+ # paths
643
+ parser.add_argument("--train_path", type=str, default=f"../amazon_train.pkl", help="Path to the training data.")
644
+ parser.add_argument("--test_path", type=str, default=f"../amazon_test.pkl", help="Path to the test data.")
645
+ parser.add_argument("--val_path", type=str, default=f"../amazon_val.pkl", help="Path to the validation data.")
646
+
647
+ # add concat_num
648
+ parser.add_argument("--concat_num", type=int, default=0, help="Number of concatenation of embeddings.")
649
+
650
+ # checkpoint save path
651
+ parser.add_argument("--checkpoint_path", type=str, default="", help="Path to save the checkpoints.")
652
+ parser.add_argument("--vector_dim", type=int, default=4, help="Dimension of the similarity vector.")
653
+
654
+
655
+ # Parse the base arguments
656
+ args = parser.parse_args()
657
+ return args
658
+
659
+
660
+ if __name__ == "__main__":
661
+
662
+ base_args = parse_args()
663
+ test_path = base_args.test_path
664
+ train_path = base_args.train_path
665
+ val_path = base_args.val_path
666
+ dataset_name = base_args.dataset_name
667
+
668
+ with open(test_path, "rb") as f:
669
+ test_data = pkl.load(f)
670
+
671
+ with open(train_path, "rb") as f:
672
+ train_data = pkl.load(f)
673
+
674
+ with open(val_path, "rb") as f:
675
+ val_data = pkl.load(f)
676
+
677
+ # load skb
678
+ skb = load_skb(dataset_name)
679
+
680
+ # set all
681
+ combo = {
682
+ "text_emb": True,
683
+ "score_vec": True,
684
+ "symb_enc": True
685
+ }
686
+ concat_num = get_concat_num(combo)
687
+
688
+ wandb.init(project=f'reranking-{dataset_name}', name=f"path")
689
+ args = argparse.Namespace(**vars(base_args), **combo)
690
+ args.concat_num = concat_num
691
+
692
+
693
+ main(train_data, val_data, test_data, skb, dataset_name, args)
694
+
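A typical launch of this script (flag names as defined in `parse_args`; the pickle paths are placeholders for the preprocessed reranking data):

    python train_eval_path_amazon.py --dataset_name amazon --train_path ../amazon_train.pkl --val_path ../amazon_val.pkl --test_path ../amazon_test.pkl --train_batch_size 64 --lr 1e-4 --epochs 100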
Reranking/train_eval_path_mag.py ADDED
@@ -0,0 +1,694 @@
1
+ import sys
+ import os # os is used on the next line, so it must be imported here
2
+ sys.path.append(os.path.dirname(os.path.abspath(os.getcwd())))
3
+
4
+ import pickle as pkl
5
+ from torch.utils.data import Dataset, DataLoader
6
+ import torch
7
+ from tqdm import tqdm
8
+ import wandb
9
+ import numpy as np
10
+ import time
11
+ from torch_scatter import segment_csr, scatter_mean
12
+ from itertools import product
13
+ from transformers import AutoModel, AutoTokenizer, AutoConfig
14
+ import torch.nn as nn
15
+ from torch.nn import CrossEntropyLoss
16
+ import random
17
+ from collections import defaultdict
18
+ import os
19
+
20
+ from Reranking.utils import move_to_cuda, seed_everything
21
+ from Reranking.rerankers.path import PathReranker
22
+ from utils import ModelForSTaRKQA
23
+ from stark_qa import load_qa, load_skb
24
+ import torch.nn.functional as F
25
+ import argparse
26
+ import json
27
+ import time
28
+
29
+ seed_everything(42)
30
+
31
+ # ***** Dataset *****
32
+ class TrainDataset(Dataset):
33
+ """
34
+ Custom Dataset for the training data.
35
+ Each instance contains multiple positive and negative candidates.
36
+ """
37
+ def __init__(self, saved_data, max_neg_candidates=100):
38
+ """
39
+ Preprocessing takes roughly 10s per 1,000 examples.
40
+ """
41
+ print(f"start processing training dataset...")
42
+ s_time = time.time()
43
+ self.max_neg_candidates = max_neg_candidates
44
+ self.sorted_query2neg = defaultdict(list)
45
+
46
+
47
+ self.text2emb_dict = saved_data['text2emb_dict']
48
+ self.data = saved_data['data']
49
+
50
+
51
+ # separate neg and pos, and prepare (query, pos) pairs
52
+ new_data = []
53
+
54
+ for i in range(len(self.data)):
55
+ neg_ids = []
56
+ pos_ids = []
57
+ item = self.data[i]
58
+
59
+
60
+ candidates_dict = item['pred_dict']
61
+ ans_ids = item['ans_ids']
62
+ # pos_ids = ans_ids
63
+ for ans_id in ans_ids:
64
+ if ans_id in candidates_dict.keys():
65
+ pos_ids.append(ans_id)
66
+ neg_ids = list(set(candidates_dict.keys()) - set(pos_ids))
67
+
68
+ # load scores vector
69
+ score_vector_dict = item['score_vector_dict']
70
+
71
+ # load the text path, str format
72
+ text_emb_dict = item['text_emb_dict']
73
+
74
+ # load the symb_enc_dict
75
+ symb_enc_dict = item['symb_enc_dict']
76
+
77
+
78
+ self.data[i]['pos_ids'] = pos_ids
79
+ self.data[i]['neg_ids'] = neg_ids
80
+
81
+ query = item['query']
82
+ for pos_id in pos_ids:
83
+ new_data.append((query, score_vector_dict[pos_id], self.text2emb_dict[text_emb_dict[pos_id]], symb_enc_dict[pos_id]))
84
+
85
+
86
+ # print(f"new_data: {new_data}")
87
+
88
+ neg_dict = {neg_id: candidates_dict[neg_id] for neg_id in neg_ids}
89
+ sorted_neg_ids = sorted(neg_dict.keys(), key=lambda x: neg_dict[x], reverse=True) # return list
90
+
91
+
92
+ self.sorted_query2neg[query] = [(score_vector_dict[neg_id], self.text2emb_dict[text_emb_dict[neg_id]], symb_enc_dict[neg_id]) for neg_id in sorted_neg_ids]
93
+
94
+
95
+ self.data = new_data
96
+ print(f"Complete data preparation")
97
+ print(f"Time: {time.time() - s_time}")
98
+
99
+
100
+
101
+
102
+ def __len__(self):
103
+ return len(self.data)
104
+
105
+ def __getitem__(self, idx):
106
+
107
+ return self.data[idx]
108
+
109
+ def collate_batch(self, pairs):
110
+ s_time = time.time()
111
+
112
+ # q
113
+ batch_q = [pair[0] for pair in pairs] # q is text
114
+ q_text = batch_q
115
+ # print(f"q111, {q_text}")
116
+
117
+
118
+ # pos
119
+ # get the score vector
120
+ batch_p_score_vector = [pair[1] for pair in pairs] # p is score vector
121
+ batch_p_score_vector = torch.tensor(batch_p_score_vector) # [bs, 4]
122
+ batch_p_score_vector = batch_p_score_vector[:, :args.vector_dim]
123
+ # get the text emb
124
+ batch_p_text_emb = [pair[2] for pair in pairs] # p is text emb
125
+ batch_p_text_emb = torch.concat(batch_p_text_emb, dim=0) # [bs, 768]
126
+ # get the symb_enc
127
+ batch_p_symb_enc = [pair[3] for pair in pairs] # p is symb_enc
128
+ batch_p_symb_enc = torch.tensor(batch_p_symb_enc) # [bs, 3]
129
+
130
+
131
+ # Negative samples
132
+ batch_n = [random.choices(self.sorted_query2neg[query], k=self.max_neg_candidates) for query in batch_q] # allow duplicates
133
+
134
+
135
+ # get the score vector
136
+ batch_n_score_vector = [pair[0] for sublist in batch_n for pair in sublist]
137
+ batch_n_score_vector = torch.tensor(batch_n_score_vector) # [bs*100, 4]
138
+ # reshape to [bs, 100, 4]
139
+ batch_n_score_vector = batch_n_score_vector.reshape(len(batch_q), self.max_neg_candidates, -1) # [bs, 100, 4]
140
+ batch_n_score_vector = batch_n_score_vector[:, :, :args.vector_dim]
141
+
142
+ # get the text emb
143
+ batch_n_text_emb = [pair[1] for sublist in batch_n for pair in sublist]
144
+ batch_n_text_emb = torch.concat(batch_n_text_emb, dim=0) # [bs*100, 768]
145
+ # reshape to [bs, 100, 768]
146
+ batch_n_text_emb = batch_n_text_emb.reshape(len(batch_q), self.max_neg_candidates, -1) # [bs, 100, 768]
147
+
148
+ # get the symb_enc
149
+ batch_n_symb_enc = [pair[2] for sublist in batch_n for pair in sublist]
150
+ batch_n_symb_enc = torch.tensor(batch_n_symb_enc) # [bs*100, 3]
151
+ # reshape to [bs, 100, 3]
152
+ batch_n_symb_enc = batch_n_symb_enc.reshape(len(batch_q), self.max_neg_candidates, -1) # [bs, 100, 3]
153
+
154
+
155
+
156
+
157
+
158
+ # Create a dictionary for the batch
159
+ feed_dict = {
160
+ 'query': q_text,
161
+ 'p_score_vector': batch_p_score_vector,
162
+ 'p_text_emb': batch_p_text_emb,
163
+ 'p_symb_enc': batch_p_symb_enc,
164
+ 'n_score_vector': batch_n_score_vector,
165
+ 'n_text_emb': batch_n_text_emb,
166
+ 'n_symb_enc': batch_n_symb_enc,
167
+
168
+ }
169
+
170
+
171
+ return feed_dict
172
+
173
+
174
+ class TestDataset(Dataset):
175
+ """
176
+ data format: {
177
+ "query": query,
178
+ "pred_dict": {node_id: score},
179
+ 'score_vector_dict': {node_id: [bm25, bm25, bm25, ada]},
180
+ "text_emb_dict": {node_id: text_emb},
181
+ "ans_ids": [],
182
+ }
183
+
184
+ """
185
+
186
+ def __init__(self, saved_data):
187
+
188
+ print(f"Start processing test dataset...")
189
+ self.text2emb_dict = saved_data['text2emb_dict']
190
+ self.data = saved_data['data']
191
+
192
+ self.text_emb_matrix = list(self.text2emb_dict.values())
193
+ self.text_emb_matrix = torch.concat(self.text_emb_matrix, dim=0)
194
+
195
+ # make the mapping between the key of text2emb_dict and the index of text_emb_matrix
196
+ self.text2idx_dict = {key: idx for idx, key in enumerate(self.text2emb_dict.keys())}
197
+
198
+ print(f"Complete data preparation: {len(self.data)}")
199
+
200
+
201
+
202
+
203
+ def __len__(self):
204
+ return len(self.data)
205
+
206
+ def __getitem__(self, idx):
207
+
208
+
209
+ # sort the pred_dict by the score
210
+ pred_dict = self.data[idx]['pred_dict']
211
+ sorted_ids = sorted(pred_dict.keys(), key=lambda x: pred_dict[x], reverse=True)
212
+ # get the top 50 candidates
213
+ sorted_ids = sorted_ids[:50]
214
+ # get the score vector
215
+ self.data[idx]['score_vector_dict'] = {key: self.data[idx]['score_vector_dict'][key] for key in sorted_ids}
216
+ # get the symb_enc_dict
217
+ self.data[idx]['symb_enc_dict'] = {key: self.data[idx]['symb_enc_dict'][key] for key in sorted_ids}
218
+ # change from the str to index
219
+ self.data[idx]['text_emb_dict'] = {key: self.text2idx_dict[value] for key, value in self.data[idx]['text_emb_dict'].items()}
220
+ self.data[idx]['text_emb_dict'] = {key: self.data[idx]['text_emb_dict'][key] for key in sorted_ids}
221
+
222
+ return self.data[idx]
223
+
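Unlike the amazon variant, this `__getitem__` first prunes each query's candidate pool to the 50 highest-scoring predictions before tensorizing. The sort-and-truncate idiom in isolation (toy dict):

    pred_dict = {7: 0.3, 2: 0.9, 5: 0.6}  # node_id -> retriever score
    sorted_ids = sorted(pred_dict, key=pred_dict.get, reverse=True)[:50]
    # -> [2, 5, 7]; the score/text/symb dicts are then re-filtered to these ids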
224
+ def collate_batch(self, batch):
225
+
226
+ # q
227
+ batch_q = [batch[i]['query'] for i in range(len(batch))]
228
+ q_text = batch_q
229
+
230
+ # c
231
+ batch_c = [list(batch[i]['score_vector_dict'].keys()) for i in range(len(batch))] # [batch, 100]
232
+ batch_c = torch.tensor(batch_c)
233
+ # print(f"111, {batch_c.shape}")
234
+ c_score_vector = [list(batch[i]['score_vector_dict'].values()) for i in range(len(batch))] # [batch, 100, 4]
235
+ c_score_vector = torch.tensor(c_score_vector)[:, :, :args.vector_dim] # [batch, 100, 4]
236
+
237
+
238
+ # print(f"222, {c_vector.shape}")
239
+ # c_text_emb
240
+ # c_text_emb = [torch.concat(list(batch[i]['text_emb_dict'].values()), dim=0).unsqueeze(0) for i in range(len(batch))]
241
+ c_text_emb = [self.text_emb_matrix[list(batch[i]['text_emb_dict'].values())].unsqueeze(0) for i in range(len(batch))]
242
+ c_text_emb = torch.concat(c_text_emb, dim=0) # [bs, 100, 768]
243
+
244
+ # c_symb_enc
245
+ c_symb_enc = [list(batch[i]['symb_enc_dict'].values()) for i in range(len(batch))]
246
+ c_symb_enc = torch.tensor(c_symb_enc) # [bs, 100, 3]
247
+
248
+
249
+ # ans_ids
250
+ ans_ids = [batch[i]['ans_ids'] for i in range(len(batch))] # list of ans_ids
251
+
252
+ # pred_ids
253
+ pred_ids = batch_c.tolist()
254
+
255
+
256
+ # Create a dictionary for the batch
257
+ feed_dict = {
258
+ 'query': q_text,
259
+ 'c_score_vector': c_score_vector,
260
+ 'c_text_emb': c_text_emb,
261
+ 'c_symb_enc': c_symb_enc,
262
+ 'ans_ids': ans_ids,
263
+ 'pred_ids': pred_ids
264
+
265
+ }
266
+
267
+
268
+
269
+ return feed_dict
270
+
271
+
272
+ # ******* loss function ********
273
+ def loss_fn(scores_pos, scores_neg):
274
+
275
+ loss_fct = CrossEntropyLoss(ignore_index=-1)
276
+
277
+ # Combine scores
278
+ scores = torch.cat([scores_pos, scores_neg.squeeze(-1)], dim=1) # [bs, 1 + max_neg_candidates]
279
+ # print(f"scores: {scores.shape}")
280
+
281
+ # Create target
282
+ target = torch.zeros(scores.size(0), dtype=torch.long).to(scores.device)
283
+
284
+ # Compute loss
285
+ loss = loss_fct(scores, target)
286
+
287
+ return loss
288
+
289
+
290
+
291
+ def batch_evaluator(skb, scores_cand, ans_ids, batch):
292
+
293
+ results = {}
294
+
295
+ # **** batch wise evaluation ****
296
+ # evaluate
297
+ candidates_ids = skb.candidate_ids
298
+ id_to_idx = {candidate_id: idx for idx, candidate_id in enumerate(candidates_ids)}
299
+
300
+
301
+ # initialize the pred_matrix
302
+ pred_matrix = torch.zeros((scores_cand.shape[0],len(candidates_ids)))
303
+
304
+
305
+ # get the index of each pred_ids
306
+ # flatten the pred_ids
307
+ flat_pred_ids = torch.tensor(batch['pred_ids']).flatten().tolist()
308
+
309
+
310
+ # get the index of each pred_ids
311
+ pred_idx = [id_to_idx[pred_id] for pred_id in flat_pred_ids]
312
+
313
+
314
+ # reshape the pred_idx
315
+ pred_idx = torch.tensor(pred_idx).reshape(scores_cand.shape[0], -1) # [bs, 100]
316
+
317
+ # move pred_matrix to the device
318
+ pred_matrix = pred_matrix.to(scores_cand.device)
319
+
320
+ # advanced indexing
321
+ pred_matrix[torch.arange(scores_cand.shape[0]).unsqueeze(1), pred_idx] = scores_cand.squeeze(-1) # [bs, num_candidates]
322
+
323
+
324
+
325
+ # Create a mapping from candidate IDs to their indices for faster lookup
326
+
327
+
328
+ # Flatten ans_ids to a single list and map them to indices
329
+ flat_ans_idx = [id_to_idx[a_id] for sublist in ans_ids for a_id in sublist]
330
+
331
+ # Create the row indices for ans_matrix corresponding to the answers
332
+ row_indices = torch.repeat_interleave(torch.arange(len(ans_ids)), torch.tensor([len(sublist) for sublist in ans_ids]))
333
+
334
+ # Create the answer matrix
335
+ ans_matrix = torch.zeros((scores_cand.shape[0], len(candidates_ids)), device=scores_cand.device)
336
+ ans_matrix[row_indices, torch.tensor(flat_ans_idx, device=scores_cand.device)] = 1
337
+
338
+
339
+
340
+
341
+
342
+ # batch computing hit1
343
+ # find the index of the max score
344
+ max_score, max_idx = torch.max(pred_matrix, dim=1)
345
+ # check the label of the max idx
346
+ batch_hit1 = ans_matrix[torch.arange(scores_cand.shape[0]), max_idx]
347
+ hit1_list = batch_hit1.tolist()
348
+
349
+
350
+
351
+
352
+ # batch computing hit@5
353
+ _, top5_idx = torch.topk(pred_matrix, 5, dim=1)
354
+ batch_hit5 = ans_matrix[torch.arange(scores_cand.shape[0]).unsqueeze(1), top5_idx]
355
+
356
+ # max with each row
357
+ batch_hit5 = torch.max(batch_hit5, dim=1)[0]
358
+ hit5_list = batch_hit5.tolist()
359
+
360
+
361
+
362
+ # batch computing recall@20
363
+ _, top20_idx = torch.topk(pred_matrix, 20, dim=1)
364
+ batch_recall20 = ans_matrix[torch.arange(scores_cand.shape[0]).unsqueeze(1), top20_idx]
365
+ # sum with each row
366
+ batch_recall20 = torch.sum(batch_recall20, dim=1)
367
+ # divide by the sum of the ans_matrix along the row
368
+ batch_recall20 = batch_recall20 / torch.sum(ans_matrix, dim=1)
369
+ recall20_list = batch_recall20.tolist()
370
+
371
+
372
+
373
+
374
+ # batch computing mrr
375
+ # find the highest rank of the answer
376
+ _, rank_idx = torch.sort(pred_matrix, dim=1, descending=True)
377
+ # query the answer matrix with the rank_idx
378
+ batch_mrr = ans_matrix[torch.arange(scores_cand.shape[0]).unsqueeze(1), rank_idx]
379
+ # find the first rank of the answer
380
+ batch_mrr = torch.argmax(batch_mrr, dim=1)
381
+ # add 1 to the rank
382
+ batch_mrr += 1
383
+ # divide by the rank
384
+ batch_mrr = 1 / batch_mrr.float()
385
+ mrr_list = batch_mrr.tolist()
386
+
387
+
388
+
389
+
390
+ results['hit@1'] = hit1_list
391
+ results['hit@5'] = hit5_list
392
+ results['recall@20'] = recall20_list
393
+ results['mrr'] = mrr_list
394
+
395
+
396
+
397
+
398
+
399
+
400
+ return results
401
+
402
+
403
+
404
+ # ***** evaluate *****
405
+ @torch.no_grad()
406
+ def evaluate(reranker, test_loader):
407
+
408
+
409
+ reranker.eval()
410
+
411
+ all_results = {
412
+ "hit@1": [],
413
+ "hit@5": [],
414
+ "recall@20": [],
415
+ "mrr": []
416
+ }
417
+ avg_results = {
418
+ "hit@1": 0,
419
+ "hit@5": 0,
420
+ "recall@20": 0,
421
+ "mrr": 0
422
+ }
423
+
424
+
425
+ # save the scores and ans_ids, and pred_ids
426
+ pred_list = []
427
+ scores_cand_list = []
428
+ ans_ids_list = []
429
+ # use tqdm to show the progress
430
+ for idx, batch in enumerate(tqdm(test_loader, desc='Evaluating', position=0)):
431
+ batch = move_to_cuda(batch)
432
+
433
+ # Check if the model is wrapped in DataParallel
434
+ if isinstance(reranker, nn.DataParallel):
435
+ scores_cand = reranker.module.eval_batch(batch) # scores_cand: [bs, n_candidates, 1]
436
+ else:
437
+ scores_cand = reranker.eval_batch(batch)
438
+
439
+
440
+ # ans_ids
441
+ ans_ids = batch['ans_ids']
442
+
443
+ results = batch_evaluator(skb, scores_cand, ans_ids, batch)
444
+
445
+
446
+ for key in results.keys():
447
+ all_results[key].extend(results[key])
448
+
449
+ # save the scores and ans_ids, and pred_ids
450
+ pred_list.extend(batch['pred_ids'])
451
+ scores_cand_list.extend(scores_cand.cpu().tolist())
452
+ ans_ids_list.extend(ans_ids)
453
+
454
+
455
+
456
+ for key in avg_results.keys():
457
+ avg_results[key] = np.mean(all_results[key])
458
+
459
+ print(f"Results: {avg_results}")
460
+
461
+
462
+
463
+ return avg_results
464
+
465
+
466
+ # ***** train *****
467
+ def main(train_data, val_data, test_data, skb, dataset_name, args):
468
+
469
+
470
+ epochs = args.epochs
471
+ device = args.device
472
+
473
+ train_size = args.train_batch_size
474
+ test_size = 64
475
+
476
+ train_dataset = TrainDataset(train_data)
477
+ train_loader = DataLoader(train_dataset, batch_size=train_size, num_workers=32, collate_fn=train_dataset.collate_batch, drop_last=True)
478
+
479
+ test_dataset = TestDataset(test_data)
480
+ test_loader = DataLoader(test_dataset, batch_size=test_size, num_workers=32, collate_fn=test_dataset.collate_batch)
481
+
482
+ val_dataset = TestDataset(val_data)
483
+ val_loader = DataLoader(val_dataset, batch_size=test_size, num_workers=32, collate_fn=val_dataset.collate_batch)
484
+
485
+
486
+ # ***** Model *****
487
+ reranker = PathReranker(socre_vector_input_dim=args.vector_dim, text_emb_input_dim=768, symb_enc_dim=3, args=args)
488
+ save_dir = f"./data/checkpoints/{dataset_name}/path"
489
+ os.makedirs(save_dir, exist_ok=True)
490
+
491
+ reranker.to(device)
492
+ # # parallel processing
493
+ reranker = nn.DataParallel(reranker)
494
+
495
+
496
+ optimizer = torch.optim.Adam(reranker.parameters(), lr=args.lr)
497
+ best_val_hit1 = float('-inf')
498
+
499
+
500
+ val_results = evaluate(reranker, val_loader)
501
+ print(f"Val evaluation")
502
+ print(val_results)
503
+
504
+
505
+ test_results = evaluate(reranker, test_loader)
506
+ print(f"Test evaluation")
507
+ print(test_results)
508
+
509
+ # log both val and test results
510
+ wandb.log({'val_mrr': val_results['mrr'], 'val_hit1': val_results['hit@1'], 'val_hit5': val_results['hit@5'], 'val_recall@20': val_results['recall@20'],
511
+ 'test_mrr': test_results['mrr'], 'test_hit1': test_results['hit@1'], 'test_hit5': test_results['hit@5'], 'test_recall@20': test_results['recall@20']})
512
+
513
+ best_test_results = {}
514
+ for epoch in tqdm(range(epochs), desc='Training Epochs', position=0):
515
+ total_loss = 0.0
516
+ reranker.train()
517
+ count = 0
518
+ total_instances = 0
519
+
520
+ for batch in tqdm(train_loader):
521
+ # print(batch)
522
+ batch = move_to_cuda(batch)
523
+ # print(batch)
524
+
525
+ scores_pos, scores_neg = reranker(batch)
526
+
527
+ # batch_loss = pairwise_loss(scores_pos, scores_neg)
528
+ batch_loss = loss_fn(scores_pos, scores_neg)
529
+
530
+ # clear optimizer
531
+ optimizer.zero_grad()
532
+ batch_loss.backward()
533
+ optimizer.step()
534
+
535
+ # total_loss += batch_loss.item()
536
+ count += 1
537
+ # compute the average loss
538
+ total_instances += scores_pos.shape[0]
539
+ total_loss += batch_loss.item()
540
+
541
+
542
+ train_loss = total_loss / total_instances
543
+
544
+ print(f"Epoch {epoch+1}/{epochs}, Average Train Loss: {train_loss}")
545
+
546
+
547
+
548
+ val_results = evaluate(reranker, val_loader)
549
+ print(f"Val evaluation")
550
+ print(val_results)
551
+
552
+
553
+ test_results = evaluate(reranker, test_loader)
554
+ print(f"Test evaluation")
555
+ print(test_results)
556
+
557
+ # log both val and test results
558
+ wandb.log({'val_mrr': val_results['mrr'], 'val_hit1': val_results['hit@1'], 'val_hit5': val_results['hit@5'], 'val_recall@20': val_results['recall@20'],
559
+ 'test_mrr': test_results['mrr'], 'test_hit1': test_results['hit@1'], 'test_hit5': test_results['hit@5'], 'test_recall@20': test_results['recall@20'],
560
+ 'train_loss': train_loss})
561
+
562
+
563
+ # save the best model when val hit1 is the highest
564
+ hit1 = val_results['hit@1']
565
+ if best_val_hit1 < hit1:
566
+ best_val_hit1 = hit1
567
+
568
+ save_path = f"{save_dir}/best_{best_val_hit1}.pth"
569
+
570
+ if isinstance(reranker, nn.DataParallel):
571
+ torch.save(reranker.module.state_dict(), save_path)
572
+ else:
573
+ torch.save(reranker.state_dict(), save_path)
574
+ print(f"Checkpoint saved at epoch {epoch+1} with test hits@1 {hit1}")
575
+
576
+ args.checkpoint_path = save_path
577
+ best_test_results = test_results
578
+
579
+
580
+
581
+ # save the last-epoch checkpoint
582
+ save_path = f"{save_dir}/last_{hit1}.pth"
583
+ if isinstance(reranker, nn.DataParallel):
584
+ torch.save(reranker.module.state_dict(), save_path)
585
+ else:
586
+ torch.save(reranker.state_dict(), save_path)
587
+ print(f"Final checkpoint saved at {save_path}")
588
+
589
+
590
+
591
+ # ***** save the results *****
592
+ results = []
593
+ results.append(
594
+ {
595
+ "config": vars(args),
596
+ "test_results": best_test_results
597
+ }
598
+ )
599
+ # save the results to json
600
+ timestamp = time.strftime("%Y%m%d-%H%M%S")
601
+ output_dir = f"./data/outputs/{dataset_name}"
602
+ os.makedirs(output_dir, exist_ok=True)
603
+ with open(f"{output_dir}/results_{timestamp}.json", "w") as f:
604
+ json.dump(results, f, indent=4)
605
+
606
+ print(best_test_results)
607
+
608
+
609
+
610
+ def get_concat_num(combo):
611
+ """
612
+ Determine the value of concat_num based on the combination of embeddings.
613
+ - score_vec adds +1
614
+ - text_emb adds +1
615
+ - symb_enc adds +3
616
+ """
617
+ concat_num = 0
618
+ if combo.get("score_vec", False): # If score_vec is True
619
+ concat_num += 1
620
+ if combo.get("text_emb", False): # If text_emb is True
621
+ concat_num += 1
622
+ if combo.get("symb_enc", False): # If symb_enc is True
623
+ concat_num += 3
624
+
625
+
626
+ return concat_num
627
+
628
+
629
+
630
+ def parse_args():
631
+
632
+ parser = argparse.ArgumentParser(description="Run PathReranker with dynamic combinations of embeddings.")
633
+
634
+ # Add arguments for model configurations
635
+ parser.add_argument("--train_batch_size", type=int, default=5, help="Batch size for training or evaluation.")
636
+ parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate for optimizer.")
637
+ parser.add_argument("--epochs", type=int, default=100, help="Number of epochs to train the model.")
638
+ parser.add_argument("--device", type=str, default="cuda", help="Device to run the model (e.g., 'cuda' or 'cpu').")
639
+
640
+ # Add arguments for the dataset
641
+ parser.add_argument("--dataset_name", type=str, default="mag", help="Name of the dataset to use.")
642
+ # paths
643
+ parser.add_argument("--train_path", type=str, default=f"../mag_train.pkl", help="Path to the training data.")
644
+ parser.add_argument("--test_path", type=str, default=f"../mag_test.pkl", help="Path to the test data.")
645
+ parser.add_argument("--val_path", type=str, default=f"../mag_val.pkl", help="Path to the validation data.")
646
+
647
+ # add concat_num
648
+ parser.add_argument("--concat_num", type=int, default=0, help="Number of concatenation of embeddings.")
649
+
650
+ # checkpoint save path
651
+ parser.add_argument("--checkpoint_path", type=str, default="", help="Path to save the checkpoints.")
652
+ parser.add_argument("--vector_dim", type=int, default=4, help="Dimension of the similarity vector.")
653
+
654
+
655
+ # Parse the base arguments
656
+ args = parser.parse_args()
657
+ return args
658
+
659
+
660
+ if __name__ == "__main__":
661
+
662
+ base_args = parse_args()
663
+
664
+ test_path = base_args.test_path
665
+ train_path = base_args.train_path
666
+ val_path = base_args.val_path
667
+ dataset_name = base_args.dataset_name
668
+
669
+ with open(test_path, "rb") as f:
670
+ test_data = pkl.load(f)
671
+
672
+ with open(train_path, "rb") as f:
673
+ train_data = pkl.load(f)
674
+
675
+ with open(val_path, "rb") as f:
676
+ val_data = pkl.load(f)
677
+
678
+ # load skb
679
+ skb = load_skb(dataset_name)
680
+
681
+ # set all
682
+ combo = {
683
+ "text_emb": True,
684
+ "score_vec": True,
685
+ "symb_enc": True
686
+ }
687
+ concat_num = get_concat_num(combo)
688
+
689
+ wandb.init(project=f'Reranking-{dataset_name}', name=f"path")
690
+ args = argparse.Namespace(**vars(base_args), **combo)
691
+ args.concat_num = concat_num
692
+
693
+ main(train_data, val_data, test_data, skb, dataset_name, args)
694
+
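Both training scripts unwrap `nn.DataParallel` before saving, so the checkpoints contain the bare `PathReranker` weights. A load-for-inference sketch under that assumption (the namespace fields and checkpoint path are placeholders; the constructor mirrors the training call, including its `socre_vector_input_dim` kwarg spelling):

    import argparse
    import torch
    from Reranking.rerankers.path import PathReranker

    # Rebuild the feature-combo namespace the constructor expects (assumed fields).
    args = argparse.Namespace(vector_dim=4, concat_num=5,
                              text_emb=True, score_vec=True, symb_enc=True)
    model = PathReranker(socre_vector_input_dim=args.vector_dim, text_emb_input_dim=768,
                         symb_enc_dim=3, args=args)
    state = torch.load("./data/checkpoints/mag/path/best.pth",  # placeholder; use the file the run wrote
                       map_location="cpu")
    model.load_state_dict(state)
    model.eval()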
Reranking/train_eval_path_prime.py ADDED
@@ -0,0 +1,691 @@
1
+ import sys
2
+ import os
3
+ sys.path.append(os.path.dirname(os.path.abspath(os.getcwd())))
4
+
5
+ import pickle as pkl
6
+ from torch.utils.data import Dataset, DataLoader
7
+ import torch
8
+ from tqdm import tqdm
9
+ import wandb
10
+ import numpy as np
11
+ import time
12
+ import torch.nn as nn
13
+ from torch.nn import CrossEntropyLoss
14
+ import random
15
+ from collections import defaultdict
16
+
17
+
18
+ from Reranking.utils import move_to_cuda, seed_everything
19
+ from Reranking.rerankers.path import PathReranker
20
+ from stark_qa import load_qa, load_skb
21
+ import torch.nn.functional as F
22
+ import argparse
23
+ import json
24
+ import time
25
+
26
+ # set the seed
27
+ seed_everything(42)
28
+
29
+ # ***** Dataset *****
30
+ class TrainDataset(Dataset):
31
+ """
32
+ Custom Dataset for the training data.
33
+ Each instance contains multiple positive and negative candidates.
34
+ """
35
+ def __init__(self, saved_data, max_neg_candidates=100):
36
+ """
37
+ Preprocessing takes roughly 10s per 1,000 examples.
38
+ """
39
+ print(f"start processing training dataset...")
40
+ s_time = time.time()
41
+ self.max_neg_candidates = max_neg_candidates
42
+ self.sorted_query2neg = defaultdict(list)
43
+
44
+
45
+ self.text2emb_dict = saved_data['text2emb_dict']
46
+ self.data = saved_data['data']
47
+
48
+
49
+ # separage neg and pos, and prepare query, pos pairs
50
+ new_data = []
51
+
52
+ for i in range(len(self.data)):
53
+ neg_ids = []
54
+ pos_ids = []
55
+ item = self.data[i]
56
+
57
+
58
+ candidates_dict = item['pred_dict']
59
+ ans_ids = item['ans_ids']
60
+ # pos_ids = ans_ids
61
+ for ans_id in ans_ids:
62
+ if ans_id in candidates_dict.keys():
63
+ pos_ids.append(ans_id)
64
+ neg_ids = list(set(candidates_dict.keys()) - set(pos_ids))
65
+
66
+ # load scores vector
67
+ score_vector_dict = item['score_vector_dict']
68
+
69
+ # load the text path, str format
70
+ text_emb_dict = item['text_emb_dict']
71
+
72
+ # load the symb_enc_dict
73
+ symb_enc_dict = item['symb_enc_dict']
74
+
75
+
76
+ self.data[i]['pos_ids'] = pos_ids
77
+ self.data[i]['neg_ids'] = neg_ids
78
+
79
+ query = item['query']
80
+ for pos_id in pos_ids:
81
+ new_data.append((query, score_vector_dict[pos_id], self.text2emb_dict[text_emb_dict[pos_id]], symb_enc_dict[pos_id]))
82
+
83
+
84
+ # print(f"new_data: {new_data}")
85
+
86
+ neg_dict = {neg_id: candidates_dict[neg_id] for neg_id in neg_ids}
87
+ sorted_neg_ids = sorted(neg_dict.keys(), key=lambda x: neg_dict[x], reverse=True) # return list
88
+
89
+
90
+ self.sorted_query2neg[query] = [(score_vector_dict[neg_id], self.text2emb_dict[text_emb_dict[neg_id]], symb_enc_dict[neg_id]) for neg_id in sorted_neg_ids]
91
+
92
+
93
+ self.data = new_data
94
+ print(f"Complete data preparation")
95
+ print(f"Time: {time.time() - s_time}")
96
+
97
+
98
+
99
+
100
+ def __len__(self):
101
+ return len(self.data)
102
+
103
+ def __getitem__(self, idx):
104
+
105
+ return self.data[idx]
106
+
107
+ def collate_batch(self, pairs):
108
+ s_time = time.time()
109
+
110
+ # q
111
+ batch_q = [pair[0] for pair in pairs] # q is text
112
+ q_text = batch_q
113
+ # print(f"q111, {q_text}")
114
+
115
+
116
+ # pos
117
+ # get the score vector
118
+ batch_p_score_vector = [pair[1] for pair in pairs] # p is score vector
119
+ batch_p_score_vector = torch.tensor(batch_p_score_vector) # [bs, 4]
120
+ batch_p_score_vector = batch_p_score_vector[:, :args.vector_dim]
121
+ # get the text emb
122
+ batch_p_text_emb = [pair[2] for pair in pairs] # p is text emb
123
+ batch_p_text_emb = torch.concat(batch_p_text_emb, dim=0) # [bs, 768]
124
+ # get the symb_enc
125
+ batch_p_symb_enc = [pair[3] for pair in pairs] # p is symb_enc
126
+ batch_p_symb_enc = torch.tensor(batch_p_symb_enc) # [bs, 3]
127
+
128
+
129
+ # Negative samples
130
+ batch_n = [random.choices(self.sorted_query2neg[query], k=self.max_neg_candidates) for query in batch_q] # allow duplicates
131
+
132
+
133
+ # get the score vector
134
+ batch_n_score_vector = [pair[0] for sublist in batch_n for pair in sublist]
135
+ batch_n_score_vector = torch.tensor(batch_n_score_vector) # [bs*100, 4]
136
+ # reshape to [bs, 100, 4]
137
+ batch_n_score_vector = batch_n_score_vector.reshape(len(batch_q), self.max_neg_candidates, -1) # [bs, 100, 4]
138
+ batch_n_score_vector = batch_n_score_vector[:, :, :args.vector_dim]
139
+
140
+ # get the text emb
141
+ batch_n_text_emb = [pair[1] for sublist in batch_n for pair in sublist]
142
+ batch_n_text_emb = torch.concat(batch_n_text_emb, dim=0) # [bs*100, 768]
143
+ # reshape to [bs, 100, 768]
144
+ batch_n_text_emb = batch_n_text_emb.reshape(len(batch_q), self.max_neg_candidates, -1) # [bs, 100, 768]
145
+
146
+ # get the symb_enc
147
+ batch_n_symb_enc = [pair[2] for sublist in batch_n for pair in sublist]
148
+ batch_n_symb_enc = torch.tensor(batch_n_symb_enc) # [bs*100, 3]
149
+ # reshape to [bs, 100, 3]
150
+ batch_n_symb_enc = batch_n_symb_enc.reshape(len(batch_q), self.max_neg_candidates, -1) # [bs, 100, 3]
151
+
152
+
153
+
154
+
155
+
156
+ # Create a dictionary for the batch
157
+ feed_dict = {
158
+ 'query': q_text,
159
+ 'p_score_vector': batch_p_score_vector,
160
+ 'p_text_emb': batch_p_text_emb,
161
+ 'p_symb_enc': batch_p_symb_enc,
162
+ 'n_score_vector': batch_n_score_vector,
163
+ 'n_text_emb': batch_n_text_emb,
164
+ 'n_symb_enc': batch_n_symb_enc,
165
+
166
+ }
167
+
168
+
169
+ return feed_dict
170
+
171
+
172
+ class TestDataset(Dataset):
173
+ """
174
+ data format: {
175
+ "query": query,
176
+ "pred_dict": {node_id: score},
177
+ 'score_vector_dict': {node_id: [bm25, bm25, bm25, ada]},
178
+ "text_emb_dict": {node_id: text_emb},
179
+ "ans_ids": [],
180
+ }
181
+
182
+ """
183
+
184
+ def __init__(self, saved_data):
185
+
186
+ print(f"Start processing test dataset...")
187
+ self.text2emb_dict = saved_data['text2emb_dict']
188
+ self.data = saved_data['data']
189
+
190
+ self.text_emb_matrix = list(self.text2emb_dict.values())
191
+ self.text_emb_matrix = torch.concat(self.text_emb_matrix, dim=0)
192
+
193
+ # make the mapping between the key of text2emb_dict and the index of text_emb_matrix
194
+ self.text2idx_dict = {key: idx for idx, key in enumerate(self.text2emb_dict.keys())}
195
+
196
+ print(f"Complete data preparation: {len(self.data)}")
197
+
198
+
199
+
200
+
201
+ def __len__(self):
202
+ return len(self.data)
203
+
204
+ def __getitem__(self, idx):
205
+
206
+ # sort the pred_dict by the score
207
+ pred_dict = self.data[idx]['pred_dict']
208
+ sorted_ids = sorted(pred_dict.keys(), key=lambda x: pred_dict[x], reverse=True)
209
+ # get the top 50 candidates
210
+ sorted_ids = sorted_ids[:50]
211
+ # get the score vector
212
+ self.data[idx]['score_vector_dict'] = {key: self.data[idx]['score_vector_dict'][key] for key in sorted_ids}
213
+ # get the symb_enc_dict
214
+ self.data[idx]['symb_enc_dict'] = {key: self.data[idx]['symb_enc_dict'][key] for key in sorted_ids}
215
+ # change from the str to index
216
+ self.data[idx]['text_emb_dict'] = {key: self.text2idx_dict[value] for key, value in self.data[idx]['text_emb_dict'].items()}
217
+ self.data[idx]['text_emb_dict'] = {key: self.data[idx]['text_emb_dict'][key] for key in sorted_ids}
218
+
219
+ return self.data[idx]
220
+
221
+ def collate_batch(self, batch):
222
+
223
+ # q
224
+ batch_q = [batch[i]['query'] for i in range(len(batch))]
225
+ q_text = batch_q
226
+
227
+ # c
228
+ batch_c = [list(batch[i]['score_vector_dict'].keys()) for i in range(len(batch))] # [batch, 100]
229
+ batch_c = torch.tensor(batch_c)
230
+ # print(f"111, {batch_c.shape}")
231
+ c_score_vector = [list(batch[i]['score_vector_dict'].values()) for i in range(len(batch))] # [batch, 100, 4]
232
+ c_score_vector = torch.tensor(c_score_vector)[:, :, :args.vector_dim] # [batch, 100, 4]
233
+
234
+
235
+ # print(f"222, {c_vector.shape}")
236
+ # c_text_emb
237
+ # c_text_emb = [torch.concat(list(batch[i]['text_emb_dict'].values()), dim=0).unsqueeze(0) for i in range(len(batch))]
238
+ c_text_emb = [self.text_emb_matrix[list(batch[i]['text_emb_dict'].values())].unsqueeze(0) for i in range(len(batch))]
239
+ c_text_emb = torch.concat(c_text_emb, dim=0) # [bs, 100, 768]
240
+
241
+ # c_symb_enc
242
+ c_symb_enc = [list(batch[i]['symb_enc_dict'].values()) for i in range(len(batch))]
243
+ c_symb_enc = torch.tensor(c_symb_enc) # [bs, 100, 3]
244
+
245
+
246
+ # ans_ids
247
+ ans_ids = [batch[i]['ans_ids'] for i in range(len(batch))] # list of ans_ids
248
+
249
+ # pred_ids
250
+ pred_ids = batch_c.tolist()
251
+
252
+
253
+ # Create a dictionary for the batch
254
+ feed_dict = {
255
+ 'query': q_text,
256
+ 'c_score_vector': c_score_vector,
257
+ 'c_text_emb': c_text_emb,
258
+ 'c_symb_enc': c_symb_enc,
259
+ 'ans_ids': ans_ids,
260
+ 'pred_ids': pred_ids
261
+
262
+ }
263
+
264
+
265
+
266
+ return feed_dict
267
+
268
+
269
+ # ******* loss function ********
270
+ def loss_fn(scores_pos, scores_neg):
271
+
272
+ loss_fct = CrossEntropyLoss(ignore_index=-1)
273
+
274
+ # Combine scores
275
+ scores = torch.cat([scores_pos, scores_neg.squeeze(-1)], dim=1) # [bs, 1 + max_neg_candidates]
276
+ # print(f"scores: {scores.shape}")
277
+
278
+ # Create target
279
+ target = torch.zeros(scores.size(0), dtype=torch.long).to(scores.device)
280
+
281
+ # Compute loss
282
+ loss = loss_fct(scores, target)
283
+
284
+ return loss
285
+
286
+
287
+
288
+ def batch_evaluator(skb, scores_cand, ans_ids, batch):
289
+
290
+ results = {}
291
+
292
+ # **** batch wise evaluation ****
293
+ # evaluate
294
+ candidates_ids = skb.candidate_ids
295
+ id_to_idx = {candidate_id: idx for idx, candidate_id in enumerate(candidates_ids)}
296
+
297
+
298
+ # initialize the pred_matrix
299
+ pred_matrix = torch.zeros((scores_cand.shape[0],len(candidates_ids)))
300
+
301
+
302
+ # get the index of each pred_ids
303
+ # flatten the pred_ids
304
+ flat_pred_ids = torch.tensor(batch['pred_ids']).flatten().tolist()
305
+
306
+
307
+ # get the index of each pred_ids
308
+ pred_idx = [id_to_idx[pred_id] for pred_id in flat_pred_ids]
309
+
310
+
311
+ # reshape the pred_idx
312
+ pred_idx = torch.tensor(pred_idx).reshape(scores_cand.shape[0], -1) # [bs, 100]
313
+
314
+ # move pred_matrix to the device
315
+ pred_matrix = pred_matrix.to(scores_cand.device)
316
+
317
+ # advanced indexing
318
+ pred_matrix[torch.arange(scores_cand.shape[0]).unsqueeze(1), pred_idx] = scores_cand.squeeze(-1) # [bs, num_candidates]
319
+
320
+
321
+
322
+ # Create a mapping from candidate IDs to their indices for faster lookup
323
+
324
+
325
+ # Flatten ans_ids to a single list and map them to indices
326
+ flat_ans_idx = [id_to_idx[a_id] for sublist in ans_ids for a_id in sublist]
327
+
328
+ # Create the row indices for ans_matrix corresponding to the answers
329
+ row_indices = torch.repeat_interleave(torch.arange(len(ans_ids)), torch.tensor([len(sublist) for sublist in ans_ids]))
330
+
331
+ # Create the answer matrix
332
+ ans_matrix = torch.zeros((scores_cand.shape[0], len(candidates_ids)), device=scores_cand.device)
333
+ ans_matrix[row_indices, torch.tensor(flat_ans_idx, device=scores_cand.device)] = 1
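+ # ans_matrix[i, j] = 1 marks candidate j as a ground-truth answer for query i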
334
+
335
+
336
+
337
+
338
+
339
+ # batch computing hit1
340
+ # find the index of the max score
341
+ max_score, max_idx = torch.max(pred_matrix, dim=1)
342
+ # check the label of the max idx
343
+ batch_hit1 = ans_matrix[torch.arange(scores_cand.shape[0]), max_idx]
344
+ hit1_list = batch_hit1.tolist()
345
+
346
+
347
+
348
+
349
+ # batch computing hit@5
350
+ _, top5_idx = torch.topk(pred_matrix, 5, dim=1)
351
+ batch_hit5 = ans_matrix[torch.arange(scores_cand.shape[0]).unsqueeze(1), top5_idx]
352
+
353
+ # max with each row
354
+ batch_hit5 = torch.max(batch_hit5, dim=1)[0]
355
+ hit5_list = batch_hit5.tolist()
356
+
357
+
358
+
359
+ # batch computing recall@20
360
+ _, top20_idx = torch.topk(pred_matrix, 20, dim=1)
361
+ batch_recall20 = ans_matrix[torch.arange(scores_cand.shape[0]).unsqueeze(1), top20_idx]
362
+ # sum with each row
363
+ batch_recall20 = torch.sum(batch_recall20, dim=1)
364
+ # divide by the sum of the ans_matrix along the row
365
+ batch_recall20 = batch_recall20 / torch.sum(ans_matrix, dim=1)
366
+ recall20_list = batch_recall20.tolist()
367
+
368
+
369
+
370
+
371
+ # batch computing mrr
372
+ # find the highest rank of the answer
373
+ _, rank_idx = torch.sort(pred_matrix, dim=1, descending=True)
374
+ # query the answer matrix with the rank_idx
375
+ batch_mrr = ans_matrix[torch.arange(scores_cand.shape[0]).unsqueeze(1), rank_idx]
376
+ # find the first rank of the answer
377
+ batch_mrr = torch.argmax(batch_mrr, dim=1)
378
+ # add 1 to the rank
379
+ batch_mrr += 1
380
+ # divide by the rank
381
+ batch_mrr = 1 / batch_mrr.float()
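+ # assumes every query has at least one answer among the scored candidates; on an all-zero row
+ # argmax would return 0 and the reciprocal rank would be overstated as 1.0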
382
+ mrr_list = batch_mrr.tolist()
383
+
384
+
385
+
386
+
387
+ results['hit@1'] = hit1_list
388
+ results['hit@5'] = hit5_list
389
+ results['recall@20'] = recall20_list
390
+ results['mrr'] = mrr_list
391
+
392
+
393
+
394
+
395
+
396
+
397
+ return results
398
+
399
+
400
+
401
+ # ***** evaluate *****
402
+ @torch.no_grad()
403
+ def evaluate(reranker, test_loader):
404
+
405
+
406
+ reranker.eval()
407
+
408
+ all_results = {
409
+ "hit@1": [],
410
+ "hit@5": [],
411
+ "recall@20": [],
412
+ "mrr": []
413
+ }
414
+ avg_results = {
415
+ "hit@1": 0,
416
+ "hit@5": 0,
417
+ "recall@20": 0,
418
+ "mrr": 0
419
+ }
420
+
421
+
422
+ # save the scores and ans_ids, and pred_ids
423
+ pred_list = []
424
+ scores_cand_list = []
425
+ ans_ids_list = []
426
+ # use tqdm to show the progress
427
+ for idx, batch in enumerate(tqdm(test_loader, desc='Evaluating', position=0)):
428
+ batch = move_to_cuda(batch)
429
+
430
+ # Check if the model is wrapped in DataParallel
431
+ if isinstance(reranker, nn.DataParallel):
432
+ scores_cand = reranker.module.eval_batch(batch) # scores_cand: [bs, num_candidates, 1]
433
+ else:
434
+ scores_cand = reranker.eval_batch(batch)
435
+
436
+
437
+ # ans_ids
438
+ ans_ids = batch['ans_ids']
439
+
440
+ results = batch_evaluator(skb, scores_cand, ans_ids, batch)
441
+
442
+
443
+ for key in results.keys():
444
+ all_results[key].extend(results[key])
445
+
446
+ # save the scores and ans_ids, and pred_ids
447
+ pred_list.extend(batch['pred_ids'])
448
+ scores_cand_list.extend(scores_cand.cpu().tolist())
449
+ ans_ids_list.extend(ans_ids)
450
+
451
+
452
+
453
+ for key in avg_results.keys():
454
+ avg_results[key] = np.mean(all_results[key])
455
+
456
+ print(f"Results: {avg_results}")
457
+
458
+
459
+
460
+ return avg_results
461
+
462
+
463
+ # ***** train *****
464
+ def main(train_data, val_data, test_data, skb, dataset_name, args):
465
+
466
+
467
+ epochs = args.epochs
468
+ device = args.device
469
+
470
+ train_size = args.train_batch_size
471
+ test_size = 64
472
+
473
+ train_dataset = TrainDataset(train_data)
474
+ train_loader = DataLoader(train_dataset, batch_size=train_size, num_workers=32, collate_fn=train_dataset.collate_batch, drop_last=True)
475
+
476
+ test_dataset = TestDataset(test_data)
477
+ test_loader = DataLoader(test_dataset, batch_size=test_size, num_workers=32, collate_fn=test_dataset.collate_batch)
478
+
479
+ val_dataset = TestDataset(val_data)
480
+ val_loader = DataLoader(val_dataset, batch_size=test_size, num_workers=32, collate_fn=val_dataset.collate_batch)
481
+
482
+
483
+ # ***** Model *****
484
+ reranker = PathReranker(socre_vector_input_dim=args.vector_dim, text_emb_input_dim=768, symb_enc_dim=3, args=args)
485
+ save_dir = f"./data/checkpoints/{dataset_name}/path"
486
+ os.makedirs(save_dir, exist_ok=True)
487
+
488
+ reranker.to(device)
489
+ # parallel processing
490
+ reranker = nn.DataParallel(reranker)
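+ # note: once wrapped, the raw module is reached via reranker.module (used below for eval_batch and checkpointing)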
491
+
492
+
493
+ optimizer = torch.optim.Adam(reranker.parameters(), lr=args.lr)
494
+ best_val_hit1 = float('-inf')
495
+
496
+
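+ # evaluate once before any training to log the untrained baseline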
497
+ val_results = evaluate(reranker, val_loader)
498
+ print(f"Val evaluation")
499
+ print(val_results)
500
+
501
+
502
+ test_results = evaluate(reranker, test_loader)
503
+ print(f"Test evaluation")
504
+ print(test_results)
505
+
506
+ # log both val and test results
507
+ wandb.log({'val_mrr': val_results['mrr'], 'val_hit1': val_results['hit@1'], 'val_hit5': val_results['hit@5'], 'val_recall@20': val_results['recall@20'],
508
+ 'test_mrr': test_results['mrr'], 'test_hit1': test_results['hit@1'], 'test_hit5': test_results['hit@5'], 'test_recall@20': test_results['recall@20']})
509
+
510
+ best_test_results = {}
511
+ for epoch in tqdm(range(epochs), desc='Training Epochs', position=0):
512
+ total_loss = 0.0
513
+ reranker.train()
514
+ count = 0
515
+ total_instances = 0
516
+
517
+ for batch in tqdm(train_loader):
518
+
519
+ batch = move_to_cuda(batch)
520
+
521
+
522
+ scores_pos, scores_neg = reranker(batch)
523
+
524
+ # cross-entropy ranking loss with the positive candidate at index 0
525
+ batch_loss = loss_fn(scores_pos, scores_neg)
526
+
527
+ # clear optimizer
528
+ optimizer.zero_grad()
529
+ batch_loss.backward()
530
+ optimizer.step()
531
+
532
+ # accumulate running loss and instance count for the epoch-average loss
533
+ count += 1
534
+
535
+ total_instances += scores_pos.shape[0]
536
+ total_loss += batch_loss.item()
537
+
538
+
539
+ train_loss = total_loss / total_instances
540
+
541
+ print(f"Epoch {epoch+1}/{epochs}, Average Train Loss: {train_loss}")
542
+
543
+
544
+
545
+ val_results = evaluate(reranker, val_loader)
546
+ print(f"Val evaluation")
547
+ print(val_results)
548
+
549
+
550
+ test_results = evaluate(reranker, test_loader)
551
+ print(f"Test evaluation")
552
+ print(test_results)
553
+
554
+ # log both val and test results
555
+ wandb.log({'val_mrr': val_results['mrr'], 'val_hit1': val_results['hit@1'], 'val_hit5': val_results['hit@5'], 'val_recall@20': val_results['recall@20'],
556
+ 'test_mrr': test_results['mrr'], 'test_hit1': test_results['hit@1'], 'test_hit5': test_results['hit@5'], 'test_recall@20': test_results['recall@20'],
557
+ 'train_loss': train_loss})
558
+
559
+
560
+ # save the best model when val hit1 is the highest
561
+ hit1 = val_results['hit@1']
562
+ if best_val_hit1 < hit1:
563
+ best_val_hit1 = hit1
564
+
565
+ save_path = f"{save_dir}/best_{best_val_hit1}.pth"
566
+
567
+ if isinstance(reranker, nn.DataParallel):
568
+ torch.save(reranker.module.state_dict(), save_path)
569
+ else:
570
+ torch.save(reranker.state_dict(), save_path)
571
+ print(f"Checkpoint saved at epoch {epoch+1} with test hits@1 {hit1}")
572
+
573
+ args.checkpoint_path = save_path
574
+ best_test_results = test_results
575
+
576
+
577
+
578
+ # save the last-epoch checkpoint
579
+ save_path = f"{save_dir}/last_{hit1}.pth"
580
+ if isinstance(reranker, nn.DataParallel):
581
+ torch.save(reranker.module.state_dict(), save_path)
582
+ else:
583
+ torch.save(reranker.state_dict(), save_path)
584
+ print(f"Final checkpoint saved at {save_path}")
585
+
586
+
587
+
588
+ # ***** save the results *****
589
+ results = []
590
+ results.append(
591
+ {
592
+ "config": vars(args),
593
+ "test_results": best_test_results
594
+ }
595
+ )
596
+ # save the results to json
597
+ timestamp = time.strftime("%Y%m%d-%H%M%S")
598
+ output_dir = f"./data/outputs/{dataset_name}"
599
+ os.makedirs(output_dir, exist_ok=True)
600
+ with open(f"{output_dir}/results_{timestamp}.json", "w") as f:
601
+ json.dump(results, f, indent=4)
602
+
603
+ print(best_test_results)
604
+
605
+ def get_concat_num(combo):
606
+ """
607
+ Determine the value of concat_num based on the combination of embeddings.
608
+ - score_vec adds +1
609
+ - text_emb adds +1
610
+ - symb_enc adds +3
611
+ """
612
+ concat_num = 0
613
+ if combo.get("score_vec", False): # If score_vec is True
614
+ concat_num += 1
615
+ if combo.get("text_emb", False): # If text_emb is True
616
+ concat_num += 1
617
+ if combo.get("symb_enc", False): # If symb_enc is True
618
+ concat_num += 3
619
+
620
+
621
+ return concat_num
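+ # e.g. all three enabled: 1 (score_vec) + 1 (text_emb) + 3 (symb_enc) = 5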
622
+
623
+
624
+
625
+ def parse_args():
626
+
627
+ parser = argparse.ArgumentParser(description="Run Pathreranker with dynamic combinations of embeddings.")
628
+
629
+ # Add arguments for model configurations
630
+ parser.add_argument("--train_batch_size", type=int, default=256, help="Batch size for training or evaluation.")
631
+ parser.add_argument("--lr", type=float, default=3e-5, help="Learning rate for optimizer.")
632
+ parser.add_argument("--epochs", type=int, default=100, help="Number of epochs to train the model.")
633
+ parser.add_argument("--device", type=str, default="cuda", help="Device to run the model (e.g., 'cuda' or 'cpu').")
634
+
635
+ # Add arguments for the dataset
636
+ parser.add_argument("--dataset_name", type=str, default="prime", help="Name of the dataset to use.")
637
+ # paths
638
+ parser.add_argument("--train_path", type=str, default=f"../prime_train.pkl", help="Path to the training data.")
639
+ parser.add_argument("--test_path", type=str, default=f"../prime_test.pkl", help="Path to the test data.")
640
+ parser.add_argument("--val_path", type=str, default=f"../prime_val.pkl", help="Path to the validation data.")
641
+
642
+ # add concat_num
643
+ parser.add_argument("--concat_num", type=int, default=0, help="Number of concatenation of embeddings.")
644
+
645
+ # checkpoint save path
646
+ parser.add_argument("--checkpoint_path", type=str, default="", help="Path to save the checkpoints.")
647
+ parser.add_argument("--vector_dim", type=int, default=4, help="Dimension of the similarity vector.")
648
+
649
+
650
+ # Parse the base arguments
651
+ args = parser.parse_args()
652
+ return args
653
+
654
+
655
+ if __name__ == "__main__":
656
+
657
+
658
+ base_args = parse_args()
659
+
660
+ dataset_name = base_args.dataset_name
661
+ train_path = base_args.train_path
662
+ test_path = base_args.test_path
663
+ val_path = base_args.val_path
664
+
665
+
666
+ with open(test_path, "rb") as f:
667
+ test_data = pkl.load(f)
668
+
669
+ with open(train_path, "rb") as f:
670
+ train_data = pkl.load(f)
671
+
672
+ with open(val_path, "rb") as f:
673
+ val_data = pkl.load(f)
674
+
675
+ # load skb
676
+ skb = load_skb(dataset_name)
677
+
678
+ # set all
679
+ combo = {
680
+ "text_emb": True,
681
+ "score_vec": True,
682
+ "symb_enc": True
683
+ }
684
+ concat_num = get_concat_num(combo)
685
+
686
+ wandb.init(project=f'Reranking-{dataset_name}', name=f"path")
687
+ args = argparse.Namespace(**vars(base_args), **combo)
688
+ args.concat_num = concat_num
689
+
690
+ main(train_data, val_data, test_data, skb, dataset_name, args)
691
+
Reranking/utils.py ADDED
@@ -0,0 +1,90 @@
1
+ import sys
2
+ from pathlib import Path
3
+ # Get the absolute path of the current script
4
+ current_file = Path(__file__).resolve()
5
+ project_root = current_file.parents[1]
6
+ # Add the project root to the system path
7
+ sys.path.append(str(project_root))
8
+
9
+ import random
10
+ import torch
11
+ import os
12
+ from stark_qa.evaluator import Evaluator
13
+ import torch.nn as nn
14
+ from typing import Any, Union, List, Dict
15
+
16
+
17
+ class ModelForSTaRKQA(nn.Module):
18
+
19
+ def __init__(self, skb, query_emb_dir='.'):
20
+ """
21
+ Initializes the model with the given knowledge base.
22
+
23
+ Args:
24
+ skb: Knowledge base containing candidate information.
25
+ """
26
+ super(ModelForSTaRKQA, self).__init__()
27
+ self.skb = skb
28
+
29
+ self.candidate_ids = skb.candidate_ids
30
+ self.evaluator = Evaluator(self.candidate_ids)
31
+
32
+ def evaluate(self,
33
+ pred_dict: Dict[int, float],
34
+ answer_ids: Union[torch.LongTensor, List[Any]],
35
+ metrics: List[str] = ['mrr', 'hit@3', 'recall@20'],
36
+ **kwargs: Any) -> Dict[str, float]:
37
+ """
38
+ Evaluates the predictions using the specified metrics.
39
+
40
+ Args:
41
+ pred_dict (Dict[int, float]): Predicted answer ids or scores.
42
+ answer_ids (torch.LongTensor): Ground truth answer ids.
43
+ metrics (List[str]): A list of metrics to be evaluated, including 'mrr', 'hit@k', 'recall@k',
44
+ 'precision@k', 'map@k', 'ndcg@k'.
45
+
46
+ Returns:
47
+ Dict[str, float]: A dictionary of evaluation metrics.
48
+ """
49
+ return self.evaluator(pred_dict, answer_ids, metrics)
50
+
51
+ def evaluate_batch(self,
52
+ pred_ids: List[int],
53
+ pred: torch.Tensor,
54
+ answer_ids: Union[torch.LongTensor, List[Any]],
55
+ metrics: List[str] = ['mrr', 'hit@3', 'recall@20'],
56
+ **kwargs: Any) -> Dict[str, float]:
57
+ return self.evaluator.evaluate_batch(pred_ids, pred, answer_ids, metrics)
58
+
59
+
60
+ def seed_everything(seed=0):
61
+ random.seed(seed)
62
+ torch.manual_seed(seed)
63
+ torch.cuda.manual_seed_all(seed)
64
+ os.environ['PYTHONHASHSEED'] = str(seed)
65
+ torch.backends.cudnn.deterministic = True
66
+ torch.backends.cudnn.benchmark = False
67
+
68
+
69
+ def move_to_cuda(sample):
70
+ if len(sample) == 0:
71
+ return {}
72
+
73
+ def _move_to_cuda(maybe_tensor):
74
+ if torch.is_tensor(maybe_tensor):
75
+ return maybe_tensor.cuda()
76
+ elif isinstance(maybe_tensor, dict):
77
+ return {
78
+ key: _move_to_cuda(value)
79
+ for key, value in maybe_tensor.items()
80
+ }
81
+ # elif isinstance(maybe_tensor, list):
82
+ # return [_move_to_cuda(x) for x in maybe_tensor]
83
+ else:
84
+ return maybe_tensor
85
+
86
+ return _move_to_cuda(sample)
87
+
88
+
89
+ if __name__ == "__main__":
90
+ print("Testing Utils")