framework
Browse files. This view is limited to 50 files because it contains too many changes.
See raw diff
- Planning/__pycache__/model.cpython-311.pyc +0 -0
- Planning/__pycache__/prompts.cpython-311.pyc +0 -0
- Planning/__pycache__/utils.cpython-311.pyc +0 -0
- Planning/data/finetune/amazon/1000.json +0 -0
- Planning/data/finetune/amazon/llama_ft.jsonl +0 -0
- Planning/data/finetune/combine_triplets.py +93 -0
- Planning/data/finetune/mag/1000.json +0 -0
- Planning/data/finetune/mag/llama_ft.jsonl +0 -0
- Planning/data/finetune/prime/1000.json +55 -0
- Planning/data/finetune/prime/llama_ft.jsonl +0 -0
- Planning/data/get_train_data/__pycache__/prompts.cpython-311.pyc +0 -0
- Planning/data/get_train_data/get_llm_data.py +237 -0
- Planning/data/get_train_data/post_process_data.py +58 -0
- Planning/data/get_train_data/prompts.py +274 -0
- Planning/data/train_eval.py +223 -0
- Planning/model.py +89 -0
- Planning/utils.py +4 -0
- Reasoning/__pycache__/mor4node.cpython-311.pyc +0 -0
- Reasoning/__pycache__/mor4node_copy.cpython-311.pyc +0 -0
- Reasoning/__pycache__/mor4path.cpython-311.pyc +0 -0
- Reasoning/__pycache__/ptp_mor4node.cpython-311.pyc +0 -0
- Reasoning/__pycache__/utils.cpython-311.pyc +0 -0
- Reasoning/mor4path.py +435 -0
- Reasoning/structural_retriever/__pycache__/stru4path.cpython-311.pyc +0 -0
- Reasoning/structural_retriever/stru4path.py +305 -0
- Reasoning/text_retrievers/__init__.py +4 -0
- Reasoning/text_retrievers/__pycache__/__init__.cpython-311.pyc +0 -0
- Reasoning/text_retrievers/__pycache__/ada.cpython-311.pyc +0 -0
- Reasoning/text_retrievers/__pycache__/bm25.cpython-311.pyc +0 -0
- Reasoning/text_retrievers/__pycache__/contriever.cpython-311.pyc +0 -0
- Reasoning/text_retrievers/__pycache__/stark_model.cpython-311.pyc +0 -0
- Reasoning/text_retrievers/ada.py +66 -0
- Reasoning/text_retrievers/bm25.py +108 -0
- Reasoning/text_retrievers/contriever.py +78 -0
- Reasoning/text_retrievers/stark_model.py +151 -0
- Reasoning/utils.py +116 -0
- Reranking/__pycache__/rerank.cpython-311.pyc +0 -0
- Reranking/__pycache__/utils.cpython-311.pyc +0 -0
- Reranking/data/checkpoints/amazon/best.pth +3 -0
- Reranking/data/checkpoints/mag/best.pth +3 -0
- Reranking/data/checkpoints/prime/best.pth +3 -0
- Reranking/rerank.py +375 -0
- Reranking/rerankers/__pycache__/node.cpython-311.pyc +0 -0
- Reranking/rerankers/__pycache__/path.cpython-311.pyc +0 -0
- Reranking/rerankers/node.py +87 -0
- Reranking/rerankers/path.py +107 -0
- Reranking/train_eval_path_amazon.py +694 -0
- Reranking/train_eval_path_mag.py +694 -0
- Reranking/train_eval_path_prime.py +691 -0
- Reranking/utils.py +90 -0
Planning/__pycache__/model.cpython-311.pyc
ADDED
Binary file (3.89 kB). View file
|
|
Planning/__pycache__/prompts.cpython-311.pyc
ADDED
Binary file (13.1 kB). View file
|
|
Planning/__pycache__/utils.cpython-311.pyc
ADDED
Binary file (493 Bytes). View file
|
|
Planning/data/finetune/amazon/1000.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
Planning/data/finetune/amazon/llama_ft.jsonl
ADDED
The diff for this file is too large to render.
See raw diff
|
|
Planning/data/finetune/combine_triplets.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
|
4 |
+
|
5 |
+
path = "./prime/1000.json"
|
6 |
+
with open(path, 'r') as f:
|
7 |
+
data = json.load(f)
|
8 |
+
|
9 |
+
def combine_triplets_by_boundary_no_loops(NTar):
|
10 |
+
combined = NTar[:]
|
11 |
+
changes = True # Track if merges happen during an iteration
|
12 |
+
|
13 |
+
while changes:
|
14 |
+
changes = False # Reset changes for this iteration
|
15 |
+
new_combined = [] # To hold the merged triplets
|
16 |
+
used = [False] * len(combined) # Track which triplets have been processed
|
17 |
+
|
18 |
+
for i in range(len(combined)):
|
19 |
+
if used[i]:
|
20 |
+
continue # Skip already processed triplets
|
21 |
+
|
22 |
+
current_triplet = combined[i] # Convert triplet to a mutable list
|
23 |
+
used[i] = True
|
24 |
+
merged = False # Track if the current triplet is merged with another
|
25 |
+
|
26 |
+
for j in range(len(combined)):
|
27 |
+
if i != j and not used[j]:
|
28 |
+
other_triplet = combined[j]
|
29 |
+
|
30 |
+
# Check if the current triplet can be merged with another
|
31 |
+
if current_triplet[-1] == other_triplet[0]: # Current's last matches other's first
|
32 |
+
current_triplet.extend(other_triplet[1:]) # Merge, excluding duplicate entity
|
33 |
+
used[j] = True
|
34 |
+
merged = True
|
35 |
+
changes = True # A merge occurred
|
36 |
+
break
|
37 |
+
elif current_triplet[0] == other_triplet[-1]: # Current's first matches other's last
|
38 |
+
current_triplet = other_triplet[:-1] + current_triplet # Merge, excluding duplicate entity
|
39 |
+
used[j] = True
|
40 |
+
merged = True
|
41 |
+
changes = True
|
42 |
+
break
|
43 |
+
|
44 |
+
# After merging or if no merge happened, add the triplet to the new list
|
45 |
+
new_combined.append(current_triplet)
|
46 |
+
|
47 |
+
# Update the combined list with the newly merged triplets
|
48 |
+
combined = new_combined
|
49 |
+
|
50 |
+
return combined
|
51 |
+
|
52 |
+
|
53 |
+
def combine_tar_ntar(tar, ntar):
|
54 |
+
for nt in ntar:
|
55 |
+
for i in range(len(tar)):
|
56 |
+
if tar[i][-1] == nt[0]:
|
57 |
+
tar[i] = tar[i]+nt[1:]
|
58 |
+
elif tar[i][0] == nt[-1]:
|
59 |
+
tar[i] = nt[:-1]+tar[i]
|
60 |
+
return tar
|
61 |
+
|
62 |
+
def check_order(routes, target):
|
63 |
+
for i in range(len(routes)):
|
64 |
+
if routes[i][-1] != target:
|
65 |
+
if routes[i][0] != target:
|
66 |
+
raise ValueError(f"Wrong order: {routes[i]}")
|
67 |
+
else:
|
68 |
+
routes[i] = routes[i][::-1]
|
69 |
+
return routes
|
70 |
+
|
71 |
+
routes_list = []
|
72 |
+
restrictions_list = []
|
73 |
+
for i in range(len(data)):
|
74 |
+
triplets = data[i]['Triplets']
|
75 |
+
target = data[i]['Target']
|
76 |
+
restrictions = data[i]['Restriction']
|
77 |
+
tar = []
|
78 |
+
ntar = []
|
79 |
+
for tp in triplets:
|
80 |
+
if target in tp:
|
81 |
+
tar.append(tp)
|
82 |
+
else:
|
83 |
+
ntar.append(tp)
|
84 |
+
if len(ntar) > 0:
|
85 |
+
ntar = combine_triplets_by_boundary_no_loops(ntar)
|
86 |
+
routes = combine_tar_ntar(tar, ntar)
|
87 |
+
else:
|
88 |
+
routes = tar
|
89 |
+
print(target)
|
90 |
+
routes = check_order(routes, target)
|
91 |
+
routes_list.append(routes)
|
92 |
+
restrictions_list.append(restrictions)
|
93 |
+
|
Planning/data/finetune/mag/1000.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
Planning/data/finetune/mag/llama_ft.jsonl
ADDED
The diff for this file is too large to render.
See raw diff
|
|
Planning/data/finetune/prime/1000.json
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[
|
2 |
+
{
|
3 |
+
"query": "Is the lens-specific intermediate filament-like protein, filensin, which is encoded by a gene expressed in the lens of camera-type eyes but not found in nasal cavity epithelial cells, involved in providing structural support to the ocular lens?",
|
4 |
+
"answer": {
|
5 |
+
"Triplets": [
|
6 |
+
[
|
7 |
+
"anatomy",
|
8 |
+
"expression present",
|
9 |
+
"gene/protein"
|
10 |
+
],
|
11 |
+
[
|
12 |
+
"anatomy",
|
13 |
+
"expression absent",
|
14 |
+
"gene/protein"
|
15 |
+
]
|
16 |
+
],
|
17 |
+
"Restriction": {
|
18 |
+
"anatomy": [
|
19 |
+
"lens of camera-type eyes",
|
20 |
+
"nasal cavity epithelial cells"
|
21 |
+
],
|
22 |
+
"gene/protein": [
|
23 |
+
"filensin"
|
24 |
+
],
|
25 |
+
"cellular_component": [
|
26 |
+
"ocular lens"
|
27 |
+
]
|
28 |
+
},
|
29 |
+
"Target": "gene/protein"
|
30 |
+
}
|
31 |
+
},
|
32 |
+
{
|
33 |
+
"query": "Which anatomical structures lack the expression of genes or proteins involved in the interaction with the fucose metabolism pathway?",
|
34 |
+
"answer": {
|
35 |
+
"Triplets": [
|
36 |
+
[
|
37 |
+
"anatomy",
|
38 |
+
"expression absent",
|
39 |
+
"gene/protein"
|
40 |
+
],
|
41 |
+
[
|
42 |
+
"gene/protein",
|
43 |
+
"interacts with",
|
44 |
+
"pathway"
|
45 |
+
]
|
46 |
+
],
|
47 |
+
"Restriction": {
|
48 |
+
"pathway": [
|
49 |
+
"fucose metabolism"
|
50 |
+
]
|
51 |
+
},
|
52 |
+
"Target": "anatomy"
|
53 |
+
}
|
54 |
+
}
|
55 |
+
]
|
Planning/data/finetune/prime/llama_ft.jsonl
ADDED
The diff for this file is too large to render.
See raw diff
|
|
Planning/data/get_train_data/__pycache__/prompts.cpython-311.pyc
ADDED
Binary file (13.1 kB). View file
|
|
Planning/data/get_train_data/get_llm_data.py
ADDED
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Desc: This file is used to get the training data from the LLM
|
3 |
+
|
4 |
+
"""
|
5 |
+
import sys
|
6 |
+
from pathlib import Path
|
7 |
+
|
8 |
+
# Get the absolute path of the current script
|
9 |
+
current_file = Path(__file__).resolve()
|
10 |
+
project_root = current_file.parents[3]
|
11 |
+
|
12 |
+
# Add the project root to the system path
|
13 |
+
sys.path.append(str(project_root))
|
14 |
+
|
15 |
+
from stark_qa import load_qa
|
16 |
+
|
17 |
+
import argparse
|
18 |
+
import os
|
19 |
+
from openai import AzureOpenAI
|
20 |
+
import json
|
21 |
+
import openai
|
22 |
+
from prompts import prompts
|
23 |
+
|
24 |
+
|
25 |
+
|
26 |
+
|
27 |
+
|
28 |
+
|
29 |
+
"""
|
30 |
+
|
31 |
+
MAG:
|
32 |
+
sys_content: 478/query
|
33 |
+
output: 45/query
|
34 |
+
input: 25/query
|
35 |
+
1000 queries
|
36 |
+
|
37 |
+
total price:
|
38 |
+
1. o1: $13.29
|
39 |
+
2. o3mini: $0.97
|
40 |
+
3. deepseek-chat: $0.24
|
41 |
+
4. deepseek-reasoner: $0.49
|
42 |
+
|
43 |
+
Amazon:
|
44 |
+
sys_content: 478/query
|
45 |
+
|
46 |
+
"""
|
47 |
+
|
48 |
+
# get the prompt for different datasets
|
49 |
+
def get_sys_content(dataset_name):
|
50 |
+
"""
|
51 |
+
input:
|
52 |
+
dataset_name: the name of the dataset
|
53 |
+
output:
|
54 |
+
sys_content: the sys_content for the dataset
|
55 |
+
"""
|
56 |
+
sys_content = prompts(dataset_name)
|
57 |
+
|
58 |
+
|
59 |
+
return sys_content
|
60 |
+
|
61 |
+
# get the response from the llm
|
62 |
+
def get_response(sys_content, user_content):
|
63 |
+
|
64 |
+
messages = [{"role": "system", "content": sys_content},
|
65 |
+
{"role": "user", "content": user_content}
|
66 |
+
]
|
67 |
+
|
68 |
+
chat_completion = client.chat.completions.create(
|
69 |
+
messages=messages,
|
70 |
+
model=parameters['azure']['model'], # parameters['azure']['model'], parameters['openai']['model']
|
71 |
+
# temperature=0,
|
72 |
+
seed=576879897,
|
73 |
+
)
|
74 |
+
response = chat_completion.choices[0].message.content
|
75 |
+
|
76 |
+
# print(messages)
|
77 |
+
# print(response)
|
78 |
+
|
79 |
+
return response
|
80 |
+
|
81 |
+
# save the outputs to json file
|
82 |
+
def save_json(data, dataset_name):
|
83 |
+
"""
|
84 |
+
input:
|
85 |
+
data: the data to be saved
|
86 |
+
dataset_name: the name of the dataset
|
87 |
+
"""
|
88 |
+
|
89 |
+
file_dir = f"/home/yongjia/dgl/Yongjia/MOE/Reasoner/data/finetune/{dataset_name}"
|
90 |
+
os.makedirs(file_dir, exist_ok=True)
|
91 |
+
file_path = f"{file_dir}/1000_{parameters['azure']['model']}.json"
|
92 |
+
|
93 |
+
with open(file_path, 'w') as f:
|
94 |
+
json.dump(data, f, indent=4)
|
95 |
+
print(f"Saved to {file_path}")
|
96 |
+
|
97 |
+
# get the reasoning graphs for a dataset
|
98 |
+
def get_rg(dataset_name):
|
99 |
+
"""
|
100 |
+
input:
|
101 |
+
dataset_name: the name of the dataset
|
102 |
+
output:
|
103 |
+
rg: the reasoning graph for the dataset
|
104 |
+
"""
|
105 |
+
|
106 |
+
# get the prompt for the dataset
|
107 |
+
sys_content = get_sys_content(dataset_name)
|
108 |
+
|
109 |
+
# get qa dataset
|
110 |
+
qa = load_qa(dataset_name)
|
111 |
+
train_qa = qa.get_subset('train')
|
112 |
+
|
113 |
+
# we sample 1000 queries from the training set
|
114 |
+
pair_list = []
|
115 |
+
failure_count = 0
|
116 |
+
for i in range(1500):
|
117 |
+
query, q_id, ans_ids, _ = train_qa[i]
|
118 |
+
|
119 |
+
# call the llm to get the reasoning graph
|
120 |
+
response = get_response(sys_content, query)
|
121 |
+
print(response)
|
122 |
+
|
123 |
+
# process the response
|
124 |
+
|
125 |
+
if dataset_name == 'prime':
|
126 |
+
output = {
|
127 |
+
"Triplets":[],
|
128 |
+
"Restriction": [],
|
129 |
+
"Target": ""
|
130 |
+
}
|
131 |
+
|
132 |
+
try:
|
133 |
+
response = response.split('\n')
|
134 |
+
triplets_raw = response[0].replace('Triplets:', '').strip()
|
135 |
+
triplets = json.loads(triplets_raw)
|
136 |
+
output['Triplets'] = triplets
|
137 |
+
|
138 |
+
restriction_raw = response[1].replace('Restriction:', '').strip()
|
139 |
+
restriction = json.loads(restriction_raw)
|
140 |
+
output['Restriction'] = restriction
|
141 |
+
|
142 |
+
target = response[2].replace('Target:', '').strip()
|
143 |
+
output['Target'] = target
|
144 |
+
except:
|
145 |
+
failure_count += 1
|
146 |
+
continue
|
147 |
+
|
148 |
+
elif dataset_name == 'mag' or dataset_name == 'amazon':
|
149 |
+
output = {
|
150 |
+
"Metapath": "",
|
151 |
+
"Restriction": [],
|
152 |
+
}
|
153 |
+
|
154 |
+
try:
|
155 |
+
response = response.split('\n')
|
156 |
+
metapath = response[0].replace('Metapath:', '').strip()
|
157 |
+
output['Metapath'] = metapath
|
158 |
+
|
159 |
+
restriction_raw = response[1].replace('Restriction:', '').strip()
|
160 |
+
restriction = json.loads(restriction_raw)
|
161 |
+
output['Restriction'] = restriction
|
162 |
+
except:
|
163 |
+
failure_count += 1
|
164 |
+
continue
|
165 |
+
|
166 |
+
else:
|
167 |
+
raise ValueError('The dataset is not supported')
|
168 |
+
|
169 |
+
pair = {'query': query, 'answer': output}
|
170 |
+
|
171 |
+
pair_list.append(pair)
|
172 |
+
|
173 |
+
if len(pair_list) == 1000:
|
174 |
+
break
|
175 |
+
|
176 |
+
# save the output to json file
|
177 |
+
save_json(pair_list, dataset_name)
|
178 |
+
print(f"Failure count: {failure_count}")
|
179 |
+
|
180 |
+
|
181 |
+
if __name__ == '__main__':
|
182 |
+
# Argument parser setup
|
183 |
+
parser = argparse.ArgumentParser(description="Load LLM parameters and initialize API clients.")
|
184 |
+
|
185 |
+
# Dataset name
|
186 |
+
parser.add_argument("--dataset_name", type=str, required=True,
|
187 |
+
choices=["mag", "amazon", "prime"],
|
188 |
+
help="Specify the dataset to use.")
|
189 |
+
|
190 |
+
# Model selection
|
191 |
+
parser.add_argument("--model", type=str, required=True,
|
192 |
+
choices=["gpt-4o-mini-20240718", "gpt-4o-2024-05-13",
|
193 |
+
"deepseek-reasoner", "gpt-o1-2024-12-17",
|
194 |
+
"o3-mini-2025-01-31"],
|
195 |
+
help="Specify the model to use.")
|
196 |
+
|
197 |
+
# Azure API parameters
|
198 |
+
parser.add_argument("--azure_api_key", type=str, default=None, help="Azure API Key")
|
199 |
+
parser.add_argument("--azure_endpoint", type=str, default=None, help="Azure API Endpoint")
|
200 |
+
parser.add_argument("--azure_api_version", type=str, default=None, help="Azure API Version")
|
201 |
+
|
202 |
+
# OpenAI API parameters
|
203 |
+
parser.add_argument("--openai_api_key", type=str, default=None, help="OpenAI API Key")
|
204 |
+
parser.add_argument("--openai_endpoint", type=str, default=None, help="OpenAI API Endpoint")
|
205 |
+
|
206 |
+
args = parser.parse_args()
|
207 |
+
|
208 |
+
# Initialize parameters dictionary
|
209 |
+
parameters = {
|
210 |
+
"azure": {
|
211 |
+
"api_key": args.azure_api_key,
|
212 |
+
"azure_endpoint": args.azure_endpoint,
|
213 |
+
"api_version": args.azure_api_version,
|
214 |
+
},
|
215 |
+
"openai": {
|
216 |
+
"api_key": args.openai_api_key,
|
217 |
+
"endpoint": args.openai_endpoint,
|
218 |
+
}
|
219 |
+
}
|
220 |
+
|
221 |
+
|
222 |
+
# Determine which API client to use
|
223 |
+
if parameters["openai"]["api_key"]:
|
224 |
+
client = openai.OpenAI(
|
225 |
+
base_url=parameters["openai"]["endpoint"],
|
226 |
+
api_key=parameters["openai"]["api_key"],
|
227 |
+
)
|
228 |
+
else:
|
229 |
+
client = AzureOpenAI(
|
230 |
+
azure_endpoint=parameters["azure"]["azure_endpoint"],
|
231 |
+
api_key=parameters["azure"]["api_key"],
|
232 |
+
api_version=parameters["azure"]["api_version"],
|
233 |
+
)
|
234 |
+
|
235 |
+
get_rg(args.dataset_name)
|
236 |
+
|
237 |
+
|
Planning/data/get_train_data/post_process_data.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Description: This script is used to post-process the data after getting it from the LLM.
|
3 |
+
|
4 |
+
input: data from llm
|
5 |
+
output: data for llama finetuning
|
6 |
+
|
7 |
+
"""
|
8 |
+
import json
|
9 |
+
import os
|
10 |
+
import copy
|
11 |
+
|
12 |
+
def save_data(data, dataset_name, model_name):
|
13 |
+
file_dir = f"../data/finetune/{dataset_name}"
|
14 |
+
os.makedirs(file_dir, exist_ok=True)
|
15 |
+
file_path = os.path.join(file_dir, f"llama_ft_{model_name}.jsonl")
|
16 |
+
|
17 |
+
with open(file_path, "w") as f:
|
18 |
+
for d in data:
|
19 |
+
f.write(json.dumps(d) + "\n")
|
20 |
+
|
21 |
+
print(f"Data saved to {file_path}")
|
22 |
+
|
23 |
+
def process(sample, dataset_name, model_name):
|
24 |
+
"""
|
25 |
+
input: sample from llm
|
26 |
+
output: sample for llama finetuning
|
27 |
+
"""
|
28 |
+
|
29 |
+
output_format = {"conversations": [{"role": "user", "content": ""}, {"role": "assistant", "content": ""}]}
|
30 |
+
ft_list = []
|
31 |
+
for i in range(len(sample)):
|
32 |
+
output = copy.deepcopy(output_format)
|
33 |
+
output["conversations"][0]["content"] = sample[i]["query"]
|
34 |
+
output["conversations"][1]["content"] = str(sample[i]["answer"])
|
35 |
+
ft_list.append(output)
|
36 |
+
|
37 |
+
# save data
|
38 |
+
save_data(ft_list, dataset_name, model_name)
|
39 |
+
|
40 |
+
|
41 |
+
# ***** Main *****
|
42 |
+
if __name__ == "__main__":
|
43 |
+
# read data
|
44 |
+
dataset_name_list = ["mag"]
|
45 |
+
model_names = ["gpt-4o-mini-20240718", "o3-mini-2025-01-31", "gpt-o1-2024-12-17", "gpt-4o-2024-05-13", "gpt-4o-mini-20240718", "gpt35-1106"] # gpt-o1-2024-12-17, "gpt-4o-mini-20240718", "gpt35-1106", o3-mini-2025-01-31
|
46 |
+
|
47 |
+
for dataset_name in dataset_name_list:
|
48 |
+
for model_name in model_names:
|
49 |
+
relative_path = f"finetune/{dataset_name}/1000.json" # f"finetune/{dataset_name}/1000_{dataset_name}.json"
|
50 |
+
current_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) # Directory of the current script
|
51 |
+
file_path = os.path.join(current_dir, relative_path)
|
52 |
+
with open(f"{file_path}", "r") as f:
|
53 |
+
sample = json.load(f)
|
54 |
+
|
55 |
+
# process data
|
56 |
+
process(sample, dataset_name, model_name)
|
57 |
+
print(f"Processing {model_name} for {dataset_name} is done.")
|
58 |
+
break
|
Planning/data/get_train_data/prompts.py
ADDED
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def prompts(dataset_name):
|
2 |
+
|
3 |
+
"""
|
4 |
+
|
5 |
+
input:
|
6 |
+
dataset_name: the name of the dataset
|
7 |
+
output:
|
8 |
+
prompt: the prompt for the dataset
|
9 |
+
"""
|
10 |
+
if dataset_name == "amazon":
|
11 |
+
sys_prompt = """
|
12 |
+
|
13 |
+
You are a reasoning graph finder agent. Your role is to:
|
14 |
+
1. Identify the underlying **meta-path** from a given question, which consists of the **entity types** at each reasoning step.
|
15 |
+
2. Extract the **content restriction** for each **entity type** based on the question. If there is no restriction for an entity type, leave its value empty.
|
16 |
+
|
17 |
+
You will be provided with a predefined **Entity Type List**. Only use the entity types from this list when constructing the meta-path and restrictions. Your response must be concise and strictly adhere to the specified **Output Format**.
|
18 |
+
|
19 |
+
"""
|
20 |
+
|
21 |
+
# Define the entity type list
|
22 |
+
entity_type_list = """
|
23 |
+
Entity Type List:
|
24 |
+
- **brand**
|
25 |
+
- **category**
|
26 |
+
- **product**
|
27 |
+
|
28 |
+
"""
|
29 |
+
|
30 |
+
demonstrations = """
|
31 |
+
Here are some examples:
|
32 |
+
**Question 1:**
|
33 |
+
What are some 5x4 inch team sports decals by Football Fanatics that are easy to apply on exterior surfaces?
|
34 |
+
|
35 |
+
**Answer 1:**
|
36 |
+
Metapath: brand -> product
|
37 |
+
Restriction: {"brand": ["Football Fanatics"], "product": ["5x4 inch team sports decals by Football Fanatics that are easy to apply on exterior surfaces"]}
|
38 |
+
|
39 |
+
**Question 2:**
|
40 |
+
Looking for men's insulated pants suitable for heavy rain, up to 10k water resistance, and with reinforced cuffs specifically for added protection in parking lots. Any suggestions?
|
41 |
+
|
42 |
+
**Answer 2:**
|
43 |
+
Metapath: category -> product
|
44 |
+
Restriction: {"category": ["pants"], "product": ["men\'s insulated pants suitable for heavy rain", "reinforced cuffs specifically for added protection in parking lots"]}
|
45 |
+
|
46 |
+
**Question 3:**
|
47 |
+
Can you recommend a dive mask that would work well with the Dive Mask Freediving Mask Spearfishing Mask Low Volume Mini Mask? I need the two masks to be compatible for my diving activities.
|
48 |
+
|
49 |
+
**Answer 3:**
|
50 |
+
Metapath: product -> product
|
51 |
+
Restriction: {"product": ["work well with the Dive Mask Freediving Mask Spearfishing Mask Low Volume Mini Mask", "compatible for diving activities", "a dive mask"]}
|
52 |
+
|
53 |
+
**Question 4:**
|
54 |
+
What are some UV protective women's golf jackets from PUMA? I'm cautious about skin protection, but I'm a big fan of PUMA.
|
55 |
+
|
56 |
+
**Answer 4:**
|
57 |
+
Metapath: brand -> product <- category
|
58 |
+
Restriction: {"brand": ["PUMA"], "category": ["golf jackets"], "product": ["UV protective women\'s golf jackets from PUMA", "skin protection"]}
|
59 |
+
|
60 |
+
|
61 |
+
|
62 |
+
"""
|
63 |
+
|
64 |
+
output_format = """
|
65 |
+
|
66 |
+
**Output Format**
|
67 |
+
Metapath: "",
|
68 |
+
Restriction: {}
|
69 |
+
"""
|
70 |
+
|
71 |
+
sys_content = sys_prompt + entity_type_list + output_format + demonstrations
|
72 |
+
|
73 |
+
|
74 |
+
elif dataset_name == "mag":
|
75 |
+
|
76 |
+
sys_prompt = """
|
77 |
+
You are a reasoning finder agent. Your role is to:
|
78 |
+
1. Identify the underlying **meta-path** from a given question, which consists of the **entity types** at each reasoning step.
|
79 |
+
2. Extract the **content restriction** for each **entity type** based on the question. If there is no restriction for an entity type, leave its value empty.
|
80 |
+
|
81 |
+
You will be provided with a predefined **Entity Type List**. Only use the entity types from this list when constructing the meta-path and restrictions. Your response must be concise and strictly adhere to the specified **Output Format**.
|
82 |
+
|
83 |
+
"""
|
84 |
+
|
85 |
+
entity_type_list = """
|
86 |
+
Entity Type List:
|
87 |
+
- **paper**
|
88 |
+
- **author**
|
89 |
+
- **institution**
|
90 |
+
- **field_of_study**
|
91 |
+
|
92 |
+
"""
|
93 |
+
|
94 |
+
output_format = """
|
95 |
+
**Output Format**
|
96 |
+
Metapath: "",
|
97 |
+
Restriction: {}
|
98 |
+
"""
|
99 |
+
|
100 |
+
|
101 |
+
demonstrations = """
|
102 |
+
|
103 |
+
Here are some examples:
|
104 |
+
**Question 1:**
|
105 |
+
Show me research articles on the association of quasi-periodic oscillations (QPOs) with noise in celestial bodies within the context of Bicoherence.
|
106 |
+
|
107 |
+
**Answer 1:**
|
108 |
+
Metapath: field_of_study -> paper
|
109 |
+
Restriction: {"field_of_study": ["oscillations", "physics"], "paper": ["association of quasi-periodic oscillations (QPOs) with noise in celestial bodies within the context of Bicoherence"]}
|
110 |
+
|
111 |
+
**Question 2:**
|
112 |
+
What research on water absorption in different frequency ranges have been referenced or deemed significant in the paper entitled 'High-resolution terahertz atmospheric water vapor continuum measurements?
|
113 |
+
|
114 |
+
**Answer 2:**
|
115 |
+
Metapath: paper -> paper
|
116 |
+
Restriction: {"paper": ["water absorption in different frequency ranges", "High-resolution terahertz atmospheric water vapor continuum measurements"]}
|
117 |
+
|
118 |
+
**Question 3:**
|
119 |
+
Show me publications by A.J. Turvey on the topic of supersymmetry particle searches.
|
120 |
+
|
121 |
+
**Answer 3:**
|
122 |
+
Metapath: author -> paper
|
123 |
+
Restriction: {"author": ["A.J. Turvey"], "paper": ["supersymmetry particle searches"]}
|
124 |
+
|
125 |
+
**Question 4:**
|
126 |
+
Looking for papers co-authored by someone involved in "Strain transferring mechanism analysis of the substrate-bonded FBG sensor". The papers should be in the similar field which is fiber optic strain sensors, and further discuss their development and application.
|
127 |
+
|
128 |
+
**Answer 4:**
|
129 |
+
Metapath: paper -> author -> paper
|
130 |
+
Restriction: {"paper": ["Strain transferring mechanism analysis of the substrate-bonded FBG sensor", "in the similar field which is fiber optic strain sensors, and further discuss their development and application"]}
|
131 |
+
|
132 |
+
**Question 5:**
|
133 |
+
Can you find any publications by the authors of "Tradeoffs in the Realization of Electrically Pumped Vertical External Cavity Surface Emitting Lasers," that delve into the topic of hybrid quantum well/quantum dot structures in the context of lasers?
|
134 |
+
|
135 |
+
**Answer 5:**
|
136 |
+
Metapath: paper -> author -> paper <- field_of_study
|
137 |
+
Restriction: {"paper": ["Tradeoffs in the Realization of Electrically Pumped Vertical External Cavity Surface Emitting Lasers", "hybrid quantum well/quantum dot structures in the context of lasers"], "field_of_study": ["lasers", "hybrid quantum well/quantum dot", "physics", "Emission spectrum"]}
|
138 |
+
|
139 |
+
**Question 6:**
|
140 |
+
Which publications from Altair Engineering authors focus on improving directional sensitivity across a wide range of frequencies?
|
141 |
+
|
142 |
+
**Answer 6:**
|
143 |
+
Metapath: institution -> author -> paper
|
144 |
+
Restriction: {"institution": ["Altair Engineering"], "paper": ["improving directional sensitivity across a wide range of frequencies"]}
|
145 |
+
|
146 |
+
**Question 7:**
|
147 |
+
Publications by Carlisle Companies authors on satellite instrumentation and space performance
|
148 |
+
|
149 |
+
**Answer 7:**
|
150 |
+
Metapath: institution -> author -> paper <- field_of_study
|
151 |
+
Restriction: {"institution": ["Carlisle Companies"], "paper": ["satellite instrumentation and space performance"], "field_of_study": ["satellite", 'space performance']}
|
152 |
+
|
153 |
+
"""
|
154 |
+
|
155 |
+
sys_content = sys_prompt + entity_type_list + output_format + demonstrations
|
156 |
+
|
157 |
+
elif dataset_name == "prime":
|
158 |
+
|
159 |
+
sys_prompt = """
|
160 |
+
You are a triplets extractor. Given a list of triplets and a query, please
|
161 |
+
1. extract the triplets contained in the query and give a list to me.
|
162 |
+
2. make a restriction list which contains the description of the entity in the query.
|
163 |
+
3. tell me which entity the query is asking for.
|
164 |
+
Your response must be concise and strictly adhere to the specified **Output Format**.
|
165 |
+
|
166 |
+
"""
|
167 |
+
|
168 |
+
triplets = """
|
169 |
+
Triplets list:
|
170 |
+
[('anatomy', 'expression absent', 'gene/protein'),
|
171 |
+
('anatomy', 'expression present', 'gene/protein'),
|
172 |
+
('anatomy', 'parent-child', 'anatomy'),
|
173 |
+
('biological_process', 'interacts with', 'exposure'),
|
174 |
+
('biological_process', 'interacts with', 'gene/protein'),
|
175 |
+
('biological_process', 'parent-child', 'biological_process'),
|
176 |
+
('cellular_component', 'interacts with', 'exposure'),
|
177 |
+
('cellular_component', 'interacts with', 'gene/protein'),
|
178 |
+
('cellular_component', 'parent-child', 'cellular_component'),
|
179 |
+
('disease', 'associated with', 'gene/protein'),
|
180 |
+
('disease', 'contraindication', 'drug'),
|
181 |
+
('disease', 'indication', 'drug'),
|
182 |
+
('disease', 'linked to', 'exposure'),
|
183 |
+
('disease', 'off-label use', 'drug'),
|
184 |
+
('disease', 'parent-child', 'disease'),
|
185 |
+
('disease', 'phenotype absent', 'effect/phenotype'),
|
186 |
+
('disease', 'phenotype present', 'effect/phenotype'),
|
187 |
+
('drug', 'carrier', 'gene/protein'),
|
188 |
+
('drug', 'contraindication', 'disease'),
|
189 |
+
('drug', 'enzyme', 'gene/protein'),
|
190 |
+
('drug', 'indication', 'disease'),
|
191 |
+
('drug', 'off-label use', 'disease'),
|
192 |
+
('drug', 'side effect', 'effect/phenotype'),
|
193 |
+
('drug', 'synergistic interaction', 'drug'),
|
194 |
+
('drug', 'target', 'gene/protein'),
|
195 |
+
('drug', 'transporter', 'gene/protein'),
|
196 |
+
('effect/phenotype', 'associated with', 'gene/protein'),
|
197 |
+
('effect/phenotype', 'parent-child', 'effect/phenotype'),
|
198 |
+
('effect/phenotype', 'phenotype absent', 'disease'),
|
199 |
+
('effect/phenotype', 'phenotype present', 'disease'),
|
200 |
+
('effect/phenotype', 'side effect', 'drug'),
|
201 |
+
('exposure', 'interacts with', 'biological_process'),
|
202 |
+
('exposure', 'interacts with', 'cellular_component'),
|
203 |
+
('exposure', 'interacts with', 'gene/protein'),
|
204 |
+
('exposure', 'interacts with', 'molecular_function'),
|
205 |
+
('exposure', 'linked to', 'disease'),
|
206 |
+
('exposure', 'parent-child', 'exposure'),
|
207 |
+
('gene/protein', 'associated with', 'disease'),
|
208 |
+
('gene/protein', 'associated with', 'effect/phenotype'),
|
209 |
+
('gene/protein', 'carrier', 'drug'),
|
210 |
+
('gene/protein', 'enzyme', 'drug'),
|
211 |
+
('gene/protein', 'expression absent', 'anatomy'),
|
212 |
+
('gene/protein', 'expression present', 'anatomy'),
|
213 |
+
('gene/protein', 'interacts with', 'biological_process'),
|
214 |
+
('gene/protein', 'interacts with', 'cellular_component'),
|
215 |
+
('gene/protein', 'interacts with', 'exposure'),
|
216 |
+
('gene/protein', 'interacts with', 'molecular_function'),
|
217 |
+
('gene/protein', 'interacts with', 'pathway'),
|
218 |
+
('gene/protein', 'ppi', 'gene/protein'),
|
219 |
+
('gene/protein', 'target', 'drug'),
|
220 |
+
('gene/protein', 'transporter', 'drug'),
|
221 |
+
('molecular_function', 'interacts with', 'exposure'),
|
222 |
+
('molecular_function', 'interacts with', 'gene/protein'),
|
223 |
+
('molecular_function', 'parent-child', 'molecular_function'),
|
224 |
+
('pathway', 'interacts with', 'gene/protein'),
|
225 |
+
('pathway', 'parent-child', 'pathway')]
|
226 |
+
|
227 |
+
"""
|
228 |
+
|
229 |
+
output_format = """
|
230 |
+
**Output Format**
|
231 |
+
Triplets: []
|
232 |
+
Restriction: {}
|
233 |
+
Target: ""
|
234 |
+
|
235 |
+
"""
|
236 |
+
|
237 |
+
demonstrations = """
|
238 |
+
Here are some examples:
|
239 |
+
**Question 1:**
|
240 |
+
Search for conditions that lack any associated treatment medications and have a connection to the formation of Onion bulb structures.
|
241 |
+
|
242 |
+
**Answer 1:**
|
243 |
+
Triplets: [["effect/phenotype", "phenotype present", "disease"], ["drug", "off-label use", "disease"]]
|
244 |
+
Restriction: {'effect/phenotype': ['formation of Onion bulb structures']}
|
245 |
+
Target: disease
|
246 |
+
|
247 |
+
**Question 2:**
|
248 |
+
What drug should be avoided for focal hand dystonia and also targets the ORM2 gene/protein?
|
249 |
+
|
250 |
+
**Answer 2:**
|
251 |
+
Triplets: [["disease", "contraindication", "drug"], ["drug", "target", "gene/protein"]]
|
252 |
+
Restriction: {"disease": ["focal hand dystonia"], "gene/protein": ["ORM2"]}
|
253 |
+
Target: drug
|
254 |
+
|
255 |
+
**Question 3:**
|
256 |
+
Could my hip muscle weakness be a sign of the mitochondrial disease my mother has, or another related condition?
|
257 |
+
Triplets: [["disease", "parent-child", "disease"], ["effect/phenotype", "phenotype present", "disease"]]
|
258 |
+
|
259 |
+
**Answer 3:**
|
260 |
+
Restriction: {"anatomy": ["hip muscle"], "disease": ["mitochondrial disease my mother has", "another related condition"], "effect/phenotype": ["weakness"]}
|
261 |
+
Target: disease
|
262 |
+
|
263 |
+
**Question 4:**
|
264 |
+
Can you find the genes and proteins involved with the activity of dolichyl-phosphate-mannose-dependent alpha-1,6-mannosyltransferase?
|
265 |
+
|
266 |
+
**Answer 4:**
|
267 |
+
Triplets: [["molecular_function", "interacts with", "gene/protein"]]
|
268 |
+
Restriction: {"molecular_function": ["dolichyl-phosphate-mannose-dependent alpha-1,6-mannosyltransferase"]}
|
269 |
+
Target: gene/protein
|
270 |
+
"""
|
271 |
+
|
272 |
+
sys_content = sys_prompt + triplets + output_format + demonstrations
|
273 |
+
|
274 |
+
return sys_content
|
Planning/data/train_eval.py
ADDED
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Description: This file is used to train and evaluate the llama model.
|
3 |
+
|
4 |
+
"""
|
5 |
+
import os
|
6 |
+
from unsloth import FastLanguageModel
|
7 |
+
from trl import SFTTrainer
|
8 |
+
from transformers import TrainingArguments, DataCollatorForSeq2Seq
|
9 |
+
from unsloth import is_bfloat16_supported
|
10 |
+
from datasets import load_dataset
|
11 |
+
from unsloth.chat_templates import train_on_responses_only
|
12 |
+
from unsloth.chat_templates import get_chat_template
|
13 |
+
from sklearn.model_selection import train_test_split
|
14 |
+
from transformers import TextStreamer
|
15 |
+
import ast
|
16 |
+
import contractions
|
17 |
+
import re
|
18 |
+
from utils import remove_inner_single_quotes
|
19 |
+
|
20 |
+
# *****load model and tokenizer*****
# Module-level setup: these globals (model, tokenizer, max_seq_length, dtype,
# load_in_4bit) are read directly by train() and evaluate() below.
max_seq_length = 2048  # maximum input length (tokens) the model accepts
dtype = None  # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True  # Use 4bit quantization to reduce memory usage.

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Llama-3.2-3B-Instruct",
    max_seq_length = max_seq_length, # specify the maximum length of input the model can accept
    dtype = dtype,
    load_in_4bit = load_in_4bit,
)

# add LoRA adapters so we only need to update 1 to 10% of all parameters!
model = FastLanguageModel.get_peft_model(
    model,
    r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",  # Reference: https://colab.research.google.com/drive/1Ys44kVvmeZtnICzWz0xgpRnrIOjZAuxp?usp=sharing#scrollTo=2eSvM9zX_2d3
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 16,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none", # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    use_rslora = False, # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
)

# Attach the Llama 3.1 chat template so apply_chat_template() renders
# conversations with the <|start_header_id|>...<|end_header_id|> markers
# that train()/evaluate() later split on.
tokenizer = get_chat_template(
    tokenizer,
    chat_template = "llama-3.1",
)
|
52 |
+
|
53 |
+
def formatting_prompts_func(data):
    """Render each conversation in a batch to a chat-formatted string.

    Uses the module-level ``tokenizer``; intended for
    ``datasets.Dataset.map(..., batched=True)``. Returns a mapping with a
    single ``"text"`` column holding one rendered string per conversation.
    """
    rendered = []
    for conversation in data['conversations']:
        rendered.append(
            tokenizer.apply_chat_template(
                conversation,
                tokenize=False,
                add_generation_prompt=False,
            )
        )
    return {"text": rendered}
|
58 |
+
|
59 |
+
def train(train_dataset, val_dataset):
    """Fine-tune the module-level LoRA model on `train_dataset`, validating on
    `val_dataset`, and save the adapter + tokenizer to a checkpoint directory.

    NOTE(review): this function reads the module-level globals `model`,
    `tokenizer`, `max_seq_length`, and — critically — `dataset_name`, which is
    only bound by the loop in the `__main__` block; calling train() without
    that global set raises NameError. Consider passing dataset_name explicitly.

    Returns:
        str: path of the saved checkpoint directory.
    """
    # format the dataset: add a "text" column with chat-rendered conversations
    texts_train = train_dataset.map(formatting_prompts_func, batched=True)
    print(texts_train)
    texts_val = val_dataset.map(formatting_prompts_func, batched=True)
    print(texts_val)

    # load trainer
    trainer = SFTTrainer(
        model = model,
        tokenizer = tokenizer,
        train_dataset = texts_train,
        eval_dataset= texts_val,
        dataset_text_field = "text",
        max_seq_length = max_seq_length,
        data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer), # A function or mechanism to batch and prepare data during training and evaluation
        dataset_num_proc = 2, # number of processes to use for data preprocessing
        packing = False, # Can make training 5x faster for short sequences.
        args = TrainingArguments(
            per_device_train_batch_size = 4,
            gradient_accumulation_steps = 8,  # effective batch size = 4 * 8 = 32
            warmup_steps = 5,
            num_train_epochs = 100, # Set this for 1 full training run. # TODO: increase the value
            max_steps = 1000, # TODO: increase the value  (NOTE: max_steps caps training before 100 epochs complete)
            learning_rate = 2e-4,
            fp16 = not is_bfloat16_supported(),
            bf16 = is_bfloat16_supported(),
            logging_steps = 1,
            optim = "adamw_8bit",
            weight_decay = 0.01,
            lr_scheduler_type = "linear",
            seed = 3407,
            output_dir = "outputs",
            do_eval=True,
            report_to='wandb',
            # NOTE(review): newer transformers versions renamed this argument
            # to `eval_strategy` — confirm against the pinned version.
            evaluation_strategy="epoch", # Specifies that evaluations will happen at the end of each epoch, using texts_val for metrics calculation. If we set the evaluation strategy without passing evaluation dataset, there is an error.
            save_strategy="epoch", # Save checkpoints based on epoch.
            load_best_model_at_end=True, # Load the best one from the disk to memory. Otherwise, the model is still the one trained after last epoch.
            metric_for_best_model="loss", # Evaluation metric on evaluation set. Make sure the metric is an option included in TrainingArguments.
            greater_is_better=False, # The metric is "greater_is_better".
            save_total_limit=1 # How many checkpoints will be saved.
        ),
    )

    # train only on outputs and ignore the loss of user's inputs
    trainer = train_on_responses_only(
        trainer,
        instruction_part = "<|start_header_id|>user<|end_header_id|>\n\n",
        response_part = "<|start_header_id|>assistant<|end_header_id|>\n\n",
    )

    trainer_stats = trainer.train()

    # save the model (LoRA adapter + tokenizer), keyed by the global dataset_name
    checkpoint_dir = f"./checkpoints/{dataset_name}"
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, f"lora_model")
    model.save_pretrained(checkpoint_path) # Local saving
    tokenizer.save_pretrained(checkpoint_path)

    print("Training completed.")

    return checkpoint_path
|
122 |
+
|
123 |
+
def evaluate(test_dataset, checkpoint_path):
    """Load the fine-tuned checkpoint and measure exact-match accuracy of the
    predicted 'Metapath' against each test example's ground-truth label.

    Each test example is expected to hold a 2-turn conversation:
    [user message, assistant label] where the label content is a Python-literal
    dict containing a 'Metapath' key. Accuracy is printed, not returned.
    """
    # load model (reads module-level globals max_seq_length / dtype / load_in_4bit)
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name = checkpoint_path,
        max_seq_length = max_seq_length, # specify the maximum length of input the model can accept
        dtype = dtype,
        load_in_4bit = load_in_4bit,
    )

    # switch unsloth model into inference mode
    FastLanguageModel.for_inference(model)

    # evaluate
    acc = 0
    for idx, dp in enumerate(test_dataset): # TODO: batch
        message = dp['conversations'][0]
        label = dp['conversations'][1]
        assert label['role'] == "assistant"

        # format the input
        inputs = tokenizer.apply_chat_template([message], tokenize = True, add_generation_prompt = True, return_tensors = 'pt').to('cuda')
        print(f"222, {inputs}")

        text_streamer = TextStreamer(tokenizer, skip_prompt = True)
        outputs = model.generate(input_ids = inputs, streamer = text_streamer, max_new_tokens = 128,
                       use_cache = True, temperature = 1.5, min_p = 0.1)
        outputs = tokenizer.batch_decode(outputs)
        # keep only the assistant turn of the decoded output
        parts = outputs[0].split("<|start_header_id|>assistant<|end_header_id|>\n\n")
        # print(f"111, {parts}")
        # NOTE(review): str.strip takes a *character set*, not a substring —
        # strip("<|eot_id|>") also removes any leading/trailing '<', '|', 'e',
        # 'o', 't', '_', 'i', 'd', '>' characters. Planner.forward uses
        # .replace("<|eot_id|>", "") instead; confirm which is intended.
        results = parts[1].strip("<|eot_id|>")
        results = contractions.fix(results)
        # Parse the generated dict; on failure, retry after scrubbing stray
        # inner single quotes, and finally fall back to an empty plan.
        try:
            results = ast.literal_eval(results)
        except:
            try:
                results = re.sub(r"\['(.*?)'", remove_inner_single_quotes, results)
                results = ast.literal_eval(results)
            except:
                results = {
                    "Metapath":"",
                    "Restriction":{},
                }

        pred_metapath = results['Metapath']

        ground_metapath = ast.literal_eval(label['content'])['Metapath']
        if pred_metapath == ground_metapath:
            acc += 1
        print(f"Prediction: {pred_metapath}")
        print(f"Ground truth: {ground_metapath}")

    print(f"Accuracy: {acc / len(test_dataset)}")
|
179 |
+
|
180 |
+
|
181 |
+
def main(dataset_name, model_name):
    """Load the fine-tuning data for `dataset_name`, split 80/10/10 into
    train/val/test, then train and evaluate the model.

    `model_name` is currently unused except in the commented-out path variants.
    """
    # *****load dataset*****
    data_dir = f"./data/finetune"
    data_path = os.path.join(data_dir, f"{dataset_name}/llama_ft.jsonl")
    # data_path = os.path.join(data_dir, f"{dataset_name}/llama_ft_{model_name}.jsonl")
    dataset = load_dataset("json", data_files=data_path)
    dataset = dataset['train']

    # Add an index column to keep track of original indices
    dataset = dataset.add_column("index", list(range(len(dataset))))

    # Perform train-test split (80% train; remaining 20% split evenly into val/test)
    train_test = dataset.train_test_split(test_size=0.2, seed=42)
    val_test = train_test['test'].train_test_split(test_size=0.5, seed=42)

    # Extract the datasets
    train_dataset = train_test['train']
    val_dataset = val_test['train']
    test_dataset = val_test['test']

    # Remove the index column if not needed
    train_dataset = train_dataset.remove_columns(['index'])
    val_dataset = val_dataset.remove_columns(['index'])
    test_dataset = test_dataset.remove_columns(['index'])

    # *****train and evaluate*****
    checkpoint_path = train(train_dataset, val_dataset)

    # NOTE(review): the path returned by train() is immediately overwritten
    # with an equivalent hard-coded path (presumably to allow skipping the
    # train() call when iterating on evaluation) — confirm and simplify.
    checkpoint_path = f"./checkpoints/{dataset_name}/lora_model"
    # checkpoint_path = f"./checkpoints/{dataset_name}/lora_model_{model_name}"
    evaluate(test_dataset, checkpoint_path)
|
213 |
+
|
214 |
+
if __name__ == "__main__":
    # NOTE: the loop variable names matter — train() reads the module-level
    # global `dataset_name` when building its checkpoint directory, so the
    # loop below is what binds it.
    datasets_to_run = ['mag', 'amazon', 'prime']
    planner_model_names = ["4o"]
    for dataset_name in datasets_to_run:
        for model_name in planner_model_names:
            main(dataset_name, model_name)
|
220 |
+
|
221 |
+
|
222 |
+
|
223 |
+
|
Planning/model.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import os
|
3 |
+
import torch.nn as nn
|
4 |
+
import ast
|
5 |
+
from unsloth import FastLanguageModel
|
6 |
+
from transformers import TextStreamer
|
7 |
+
import contractions
|
8 |
+
import re
|
9 |
+
from Planning.utils import remove_inner_single_quotes
|
10 |
+
|
11 |
+
|
12 |
+
class Planner(nn.Module):
    """Wraps a fine-tuned LoRA LLM that turns a natural-language query into a
    plan dict of the form {"Metapath": ..., "Restriction": {...}}.

    Loads the checkpoint saved under ``Planning/checkpoints/<dataset>/lora_model/``
    and runs it in unsloth inference mode on CUDA.
    """

    def __init__(self, dataset_name):
        super(Planner, self).__init__()
        self.dataset_name = dataset_name
        # Checkpoint layout matches what train_eval.train() saves.
        self.checkpoint_path = f"Planning/checkpoints/{dataset_name}/lora_model/"
        self.max_seq_length = 2048
        self.dtype = None          # auto-detect precision
        self.load_in_4bit = True   # 4-bit quantization to reduce memory

        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name = self.checkpoint_path,
            max_seq_length = self.max_seq_length,
            dtype = self.dtype,
            load_in_4bit = self.load_in_4bit
        )

        # switch unsloth model into inference mode
        FastLanguageModel.for_inference(model)

        self.model = model
        self.tokenizer = tokenizer

    def forward(self, query):
        """Generate and parse a plan dict for `query`.

        Returns a dict parsed from the model output; on any parse failure
        falls back to {"Metapath": "", "Restriction": {}}.
        """
        message = {'content': query, 'role': 'user'}
        inputs = self.tokenizer.apply_chat_template(
            [message],
            tokenize = True,
            add_generation_prompt = True, # Must add for generation
            return_tensors = "pt",
        ).to("cuda")

        text_streamer = TextStreamer(self.tokenizer, skip_prompt = True)
        outputs = self.model.generate(input_ids = inputs, streamer = text_streamer, max_new_tokens = 128, # max_new_tokens is the maximum number of new tokens generated beyond the input
                           use_cache = True, temperature = 1.5, min_p = 0.1) # min_p is a cumulative probability, which makes the generation more diverse

        # Decode and keep only the assistant turn of the chat transcript.
        outputs = self.tokenizer.batch_decode(outputs)
        parts = outputs[0].split("<|start_header_id|>assistant<|end_header_id|>\n\n")

        if len(parts) > 1:
            results = parts[1].replace("<|eot_id|>", "")
        else:
            # No assistant turn found in the decoded output.
            raise ValueError

        # ******* special processing for prime dataset: parse directly,
        # without the contractions/regex repair applied below.
        if self.dataset_name == 'prime':
            try:
                # Parse the string using ast.literal_eval
                parsed_dict = ast.literal_eval(results)

                return parsed_dict
            except (SyntaxError, ValueError) as e:
                print(f"Error parsing the string: {e}")
                return {
                    "Metapath": "",
                    "Restriction": {}
                }

        # Expand contractions (e.g. don't -> do not) so stray apostrophes
        # don't break the literal parse below.
        results = contractions.fix(results)

        try:
            results = ast.literal_eval(results)
        except:
            print(f"Fail")
            # Second attempt: strip single quotes inside ['...'] spans, then re-parse.
            try:
                results = re.sub(r"\['(.*?)'", remove_inner_single_quotes, results) # TODO: need optimize
                results = ast.literal_eval(results)
            except:
                # Final fallback: empty plan.
                results = {
                    "Metapath": "",
                    "Restriction": {},

                }
        rg = results

        return rg
|
Planning/utils.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def remove_inner_single_quotes(match):
    """``re.sub`` callback for the pattern ``r"\\['(.*?)'"``.

    Rebuilds the matched ``['...`` prefix with every single quote removed
    from the captured text. Note the trailing quote consumed by the pattern
    is intentionally not re-emitted.
    """
    captured = match.group(1)
    return "['" + captured.replace("'", "")
|
Reasoning/__pycache__/mor4node.cpython-311.pyc
ADDED
Binary file (18.8 kB). View file
|
|
Reasoning/__pycache__/mor4node_copy.cpython-311.pyc
ADDED
Binary file (15.8 kB). View file
|
|
Reasoning/__pycache__/mor4path.cpython-311.pyc
ADDED
Binary file (22.7 kB). View file
|
|
Reasoning/__pycache__/ptp_mor4node.cpython-311.pyc
ADDED
Binary file (21 kB). View file
|
|
Reasoning/__pycache__/utils.cpython-311.pyc
ADDED
Binary file (4.99 kB). View file
|
|
Reasoning/mor4path.py
ADDED
@@ -0,0 +1,435 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
from pathlib import Path
|
3 |
+
# Get the absolute path of the current script
|
4 |
+
current_file = Path(__file__).resolve()
|
5 |
+
project_root = current_file.parents[2]
|
6 |
+
# Add the project root to the system path
|
7 |
+
sys.path.append(str(project_root))
|
8 |
+
|
9 |
+
from .utils import combine_dicts, parse_metapath, get_scorer, get_text_retriever, fix_length
|
10 |
+
from models.model import ModelForSTaRKQA
|
11 |
+
|
12 |
+
|
13 |
+
class MOR4Path(ModelForSTaRKQA):
    """Mixed retriever: follows metapath routes through the knowledge base
    (structural retrieval) and blends in text-retriever candidates, returning
    a scored candidate dict plus the node paths that produced each candidate.
    """

    def __init__(self, dataset_name, text_retriever_name, scorer_name, skb, topk=100):
        super(MOR4Path, self).__init__(skb)
        self.dataset_name = dataset_name
        self.text_retriever = get_text_retriever(dataset_name, text_retriever_name, skb)
        self.scorer = get_scorer(dataset_name, scorer_name=scorer_name, skb=skb)
        # self.scorer = self.text_retriever
        self.topk = topk
        self.node_type_list = skb.node_type_lst()
        self.edge_type_list = skb.rel_type_lst()
        if self.dataset_name == "prime":
            # prime: routes are (node, relation, node) triplets; any candidate type allowed
            self.tp_list = skb.get_tuples()
            self.target_type_list = skb.candidate_types
        else:
            # mag/amazon: map (head_type, tail_type) -> relation, single target type
            self.tp_dict = {(tp[0], tp[-1]): tp[1] for tp in skb.get_tuples()}
            self.target_type_list = ['paper' if dataset_name == 'mag' else 'product']

        self.skb = skb
        self.ini_k = 5 # topk for initial retrieval
        self.mor_k = 10 # topk for textual retrieval in MOR
        self.mor_count = 0   # number of queries answered via structural retrieval
        self.num_negs = 200  # fixed negative-pool size used in training mode

    def rg2routes(self, rg):
        """
        input: rg: {"Metapath": "", "Restriction": {}}
        output: routes: [['paper', 'author', 'paper'], ['paper', 'paper']]
        """
        # parse rg: a list is already route form; a string must be parsed
        metapath = rg["Metapath"]
        if isinstance(rg["Metapath"], list):
            routes = rg["Metapath"]
        elif isinstance(rg["Metapath"], str):
            routes = parse_metapath(metapath)
        else:
            return None

        return routes

    def check_valid(self, routes, rg):
        """Filter `routes` down to those that exist in the KB schema; return
        None when nothing survives (caller then falls back to text retrieval).

        For non-prime datasets the surviving routes are rewritten to
        interleave the relation names: [n0, rel01, n1, rel12, n2, ...].
        """
        # check the length of routes
        if not routes:
            # raise ValueError(f"Empty routes: {routes}")
            return None

        if len(routes) == 1 and len(routes[0]) == 1: # single node, directly do text retrieval
            return None

        # Step 1: Filter routes by target type
        target_type_valid_routes = [
            route for route in routes if route[-1] in self.target_type_list
        ]
        if not target_type_valid_routes:

            return None

        # Step 2: Filter routes by node and edge type
        type_valid_routes = [
            route
            for route in target_type_valid_routes
            if all(
                node in self.node_type_list or node in self.edge_type_list
                for node in route
            )
        ]
        if not type_valid_routes:

            return None

        # Step 3: Check existence of relations
        relation_valid_routes = []
        for route in type_valid_routes:
            if self.dataset_name == "prime":
                # prime routes already interleave relations; need >= 1 triplet
                if len(route) < 3:
                    continue
                triplets = [
                    (route[i], route[i + 1], route[i + 2])
                    for i in range(0, len(route) - 2, 2)
                ]

                if all(tp in self.tp_list for tp in triplets) and all(len(tp) == 3 for tp in triplets): # and all length of triplets is 3
                    relation_valid_routes.append(route)
            else:
                pairs = [(route[i], route[i + 1]) for i in range(len(route) - 1)]
                if all(tp in self.tp_dict.keys() for tp in pairs):
                    relations = [self.tp_dict[tp] for tp in pairs]

                    # make route with relations
                    new_route = []
                    for i in range(len(relations)):
                        new_route.append(pairs[i][0])
                        new_route.append(relations[i])
                    new_route.append(pairs[-1][-1])

                    relation_valid_routes.append(new_route)

        if not relation_valid_routes:

            return None

        return relation_valid_routes

    def get_candidates4route(self, query, q_id, route, restriction):
        """Walk one relation-interleaved route through the KB, seeding from a
        text retrieval on the route's first node type.

        Side effects: appends this route's per-candidate paths and score
        vectors to self.paths / self.bm_vector_dict (both lists here).
        Returns the candidate node ids reached at the end of the route.
        """
        # initialization: seed nodes of the route's first type, biased by any
        # restriction text for that type
        ini_node_type = route[0]

        try:
            type_restr = "".join(restriction[ini_node_type])
        except:
            type_restr = ""

        ini_dict = self.text_retriever.retrieve(query + " " + type_restr, q_id=q_id, topk=self.ini_k, node_type=ini_node_type)
        current_node_ids = list(ini_dict.keys())

        # initialize the bm_vector_dict (per-node list of retrieval scores along the path)
        bm_vector_dict = {key: [value] for key, value in ini_dict.items()}

        # initilization for paths
        paths = {}
        for c_id in current_node_ids:
            paths[c_id] = [c_id]

        # loop
        hops = len(route)
        # for hop/layer: route is [node, edge, node, edge, node, ...], step 2
        for hop in range(0, hops-2, 2):
            new_paths = {}

            cur_node_type = route[hop]
            next_node_type = route[hop+2]
            edge_type = route[hop+1]
            next_node_ids = []

            new_vector_dict = {}
            # for node: expand every frontier node along this edge type
            for node_id in current_node_ids:
                neighbor_ids = self.skb.get_neighbor_nodes(idx=node_id, edge_type=edge_type)
                next_node_ids.extend(neighbor_ids)

                # ***** update paths and score_vector_dict *****
                for neighbor_id in neighbor_ids:
                    if neighbor_id not in new_paths.keys(): # only add new node
                        new_paths[neighbor_id] = paths[node_id] + [neighbor_id]
                        new_vector_dict[neighbor_id] = bm_vector_dict[node_id] + [-1] # -1 for padding

            bm_vector_dict = new_vector_dict

            # ***** layer text retrieval *****
            # if there is restriction for the next node, add text_retriever
            if next_node_type in restriction.keys() and len(restriction[next_node_type]) > 0 and restriction[next_node_type] != [""]:
                try:

                    retrieve_dict = self.text_retriever.retrieve(query+" "+"".join(restriction[next_node_type]), q_id=q_id, topk=self.mor_k, node_type=route[hop+2])

                    new_query = query+ " " + "".join(restriction[next_node_type])

                    # take union with the structurally-expanded frontier
                    next_node_ids.extend(list(set(retrieve_dict.keys())))

                    # ***** update paths and bm_vector_dict *****
                    for c_id in retrieve_dict.keys():
                        if c_id not in new_paths.keys():
                            new_paths[c_id] = [c_id]
                            bm_vector_dict[c_id] = [retrieve_dict[c_id]]

                except:
                    pass

            paths = new_paths
            current_node_ids = list(set(next_node_ids))

        candidates = current_node_ids

        self.paths.append(paths)
        self.bm_vector_dict.append(bm_vector_dict)

        return candidates

    def merge_candidate_pools(self, non_empty_candidates_lists):
        """Combine per-route candidate lists: intersection when non-empty,
        otherwise the union. Returns a set for a single list, else a list."""
        # if only one non-empy candidates list left, return it as a set
        if len(non_empty_candidates_lists) == 1:
            return set(non_empty_candidates_lists[0])

        # find the intersection candidates ids
        result = set(non_empty_candidates_lists[0])
        for lst in non_empty_candidates_lists[1:]:
            result.intersection_update(lst)

        # if the intersection is empty, return the union of all candidates
        if len(result) == 0:
            result = set()
            for lst in non_empty_candidates_lists:
                result.update(lst)

        return list(result)

    def get_mor_candidates(self, query, q_id, valid_routes, restriction):
        """Run structural retrieval for every route whose start type has a
        restriction, merge the pools, and return {node_id: -1} (scores are
        assigned later by the scorer)."""
        # Step 1: Get candidates for each route
        candidates_pool = []
        for route in valid_routes:
            if route[0] in restriction.keys() and len(restriction[route[0]]) > 0:
                candidates_pool.append(self.get_candidates4route(query, q_id, route, restriction)) # topk is the candidates retrieved from textual retriever

        # remove empty lists from candidates
        non_empty_candidates_lists = [lst for lst in candidates_pool if lst]
        if len(non_empty_candidates_lists) == 0:

            return {}

        # Step 2: Combine candidates from different routes, try intersection first, then union
        candidates = self.merge_candidate_pools(non_empty_candidates_lists) # candidates is a list

        # step 3: score the candidates, ini to -1
        pred_dict = dict(zip(candidates, [-1]*len(candidates)))
        # print(f"111, {pred_dict}")

        return pred_dict

    def check_topk(self, query, q_id, pred_dict):
        """Test mode: pad `pred_dict` up to self.topk with text-retrieved
        nodes, score everything, and truncate to the top-k by score.
        Keeps self.paths / self.bm_vector_dict in sync with the survivors."""

        missing = self.topk - len(set(pred_dict.keys()))
        if missing > 0:
            added_dict = self.text_retriever.retrieve(query, q_id, topk=self.topk+20, node_type=self.target_type_list) # +20 make it more safe
            available_nodes = {key: value for key, value in added_dict.items() if key not in pred_dict.keys()}
            sorted_available_nodes = sorted(available_nodes.items(), key=lambda x: x[1], reverse=True)
            # Select only the required number of nodes to fill the missing slots
            selected_nodes = dict(sorted_available_nodes[:missing])

            # Update pred_dict with the selected nodes
            pred_dict.update(selected_nodes)

            # updata paths (padding nodes have trivial single-node paths)
            for node_id in selected_nodes.keys():
                self.paths[node_id] = [node_id]

            # update bm_vector_dict
            new_bm_vector_dict = {key: [value] for key, value in selected_nodes.items()}
            self.bm_vector_dict.update(new_bm_vector_dict)

        scored_dict = self.scorer.score(query, q_id=q_id, candidate_ids=list(pred_dict.keys()))

        if len(scored_dict) > self.topk:
            # initiliaze the new_paths
            new_paths = {}

            # Select the top-k nodes based on the scores
            sorted_scored_dict = sorted(scored_dict.items(), key=lambda x: x[1], reverse=True)
            scored_dict = dict(sorted_scored_dict[:self.topk])

            # update paths
            for node_id in scored_dict.keys():
                new_paths[node_id] = self.paths[node_id]

            self.paths = new_paths

            # update bm_vector_dict
            new_bm_vector_dict = {node_id: self.bm_vector_dict[node_id] for node_id in scored_dict.keys()}
            self.bm_vector_dict = new_bm_vector_dict


        return scored_dict


    # check fixed negtopk
    def check_negtopk(self, query, q_id, pred_dict, ans_ids):
        """Train mode: keep all positives (ids in `ans_ids`) and a fixed pool
        of self.num_negs negatives, padding with text-retrieved non-answer
        nodes if needed. Keeps self.paths / self.bm_vector_dict in sync."""
        # check the positive nodes
        pos_ids = [node_id for node_id in ans_ids if node_id in pred_dict.keys()]
        pos_dict = {key: value for key, value in pred_dict.items() if key in pos_ids}
        neg_ids = pred_dict.keys() - set(pos_ids)
        neg_dict = {key: value for key, value in pred_dict.items() if key in neg_ids}

        # check the number of negative nodes
        missing = self.num_negs - len(neg_ids)

        if missing > 0:

            added_dict = self.text_retriever.retrieve(query, q_id, topk=self.num_negs+200, node_type=self.target_type_list) # +20 make it more safe
            available_nodes = {key: value for key, value in added_dict.items() if key not in pred_dict.keys() and key not in ans_ids}
            sorted_available_nodes = sorted(available_nodes.items(), key=lambda x: x[1], reverse=True)
            # Select only the required number of nodes to fill the missing slots
            selected_nodes = dict(sorted_available_nodes[:missing])

            # Update pred_dict with the selected nodes
            neg_dict.update(selected_nodes)

            # updata paths
            for node_id in selected_nodes.keys():
                self.paths[node_id] = [node_id]

            # update bm_vector_dict
            new_bm_vector_dict = {key: [value] for key, value in selected_nodes.items()}
            self.bm_vector_dict.update(new_bm_vector_dict)


        scored_neg_dict = self.scorer.score(query, q_id=q_id, candidate_ids=list(neg_dict.keys()))
        if pos_dict:
            scored_pos_dict = self.scorer.score(query, q_id=q_id, candidate_ids=list(pos_dict.keys()))
        else:
            scored_pos_dict = {}

        if len(scored_neg_dict) > self.num_negs:
            # Select the top-k nodes based on the scores
            sorted_scored_neg_dict = sorted(scored_neg_dict.items(), key=lambda x: x[1], reverse=True)
            scored_neg_dict = dict(sorted_scored_neg_dict[:self.num_negs])



        scored_neg_dict.update(scored_pos_dict)
        scored_dict = scored_neg_dict
        print(len(scored_dict))

        # update paths
        new_paths = {}
        for node_id in scored_dict.keys():
            new_paths[node_id] = self.paths[node_id]
        self.paths = new_paths

        # update bm_vector_dict
        new_bm_vector_dict = {node_id: self.bm_vector_dict[node_id] for node_id in scored_dict.keys()}
        self.bm_vector_dict = new_bm_vector_dict


        return scored_dict


    def forward(self, query, q_id, ans_ids, rg, args):
        """Answer one query: plan (`rg`) -> routes -> structural retrieval,
        with text retrieval as fallback; returns a dict with the scored
        candidates (`pred_dict`), per-candidate `paths`, and score vectors.

        Raises ValueError when paths and pred_dict fall out of sync.
        """

        # per-query state, mutated by the retrieval helpers below
        self.paths = []
        self.bm_vector_dict = []
        self.ada_score = {}
        # ***** Structural Retrieval *****

        # reasoning grpah to routes
        if self.dataset_name == "prime":
            routes = rg["Metapath"]
        else:
            routes = self.rg2routes(rg)

        # check valid
        valid_routes = self.check_valid(routes, rg) # add check for restriction

        if valid_routes is None:
            # do textual retrieval (fallback when no route survives validation)
            pred_dict = self.text_retriever.retrieve(query, q_id, topk=self.topk, node_type=self.target_type_list)

            # update bm_vector_dict
            self.bm_vector_dict = {key: [value] for key, value in pred_dict.items()}

        else:
            # truncate the valid_routes (keep only the last 2 hops: 5 route entries)
            if self.dataset_name == "prime":
                pass
            else:
                valid_routes = [route[-5:] for route in valid_routes]


            # do structural retrieval
            restriction = rg["Restriction"]
            pred_dict = self.get_mor_candidates(query, q_id, valid_routes, restriction)
            self.mor_count += 1


        # **** combine paths **** (per-route list of dicts -> one dict keyed by candidate)
        if self.paths:
            self.paths = combine_dicts(self.paths, pred_dict=pred_dict) # return dict
        else:
            self.paths = {}
            for node_id in pred_dict.keys():
                self.paths[node_id] = [node_id]

        # ***** combine bm_vector_dict *****
        if isinstance(self.bm_vector_dict, list):
            self.bm_vector_dict = combine_dicts(self.bm_vector_dict, pred_dict=pred_dict)

        # **** fix neg for training; fix candidates for testing ****
        if args.mod == "train":
            # check neg topk
            pred_dict = self.check_negtopk(query, q_id, pred_dict, ans_ids)
        else:
            # check topk
            pred_dict = self.check_topk(query, q_id, pred_dict)


        # **** length padding and truncate *****
        if self.dataset_name != "prime":
            self.paths = fix_length(self.paths)

        if len(self.paths) != len(pred_dict):
            print(f"paths: {self.paths}")
            print(f"pred_dict: {pred_dict}")
            raise ValueError(f"Length mismatch between paths and pred_dict: {len(self.paths)}, {len(pred_dict)}")

        output = {
            "query": query,
            "pred_dict": pred_dict,
            "ans_ids": ans_ids,
            'paths': self.paths,
            'bm_vector_dict': self.bm_vector_dict,
            'rg': rg
        }


        return output
|
431 |
+
|
432 |
+
if __name__ == "__main__":
    # Smoke-test marker only; importing this module elsewhere skips this branch.
    print(f"Test mor4path")
|
434 |
+
|
435 |
+
|
Reasoning/structural_retriever/__pycache__/stru4path.cpython-311.pyc
ADDED
Binary file (13 kB). View file
|
|
Reasoning/structural_retriever/stru4path.py
ADDED
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
input: rg
|
3 |
+
output (fixed 100 candidates, for path-based reranking):
|
4 |
+
{
|
5 |
+
"query": query,
|
6 |
+
"pred_dict": {node_id: score},
|
7 |
+
"ans_ids": [],
|
8 |
+
'paths': {node_id: [node_ids_path]}
|
9 |
+
}
|
10 |
+
|
11 |
+
"""
|
12 |
+
import sys
|
13 |
+
import os
|
14 |
+
sys.path.append(os.path.dirname(os.path.abspath(os.getcwd())))
|
15 |
+
|
16 |
+
from utils import combine_dicts, parse_metapath, get_scorer, get_text_retriever, fix_length
|
17 |
+
from models.model import ModelForSTaRKQA
|
18 |
+
import time
|
19 |
+
|
20 |
+
|
21 |
+
|
22 |
+
class Stru4Path(ModelForSTaRKQA):
    """Structural retriever: converts a reasoning graph (metapath +
    restrictions) into candidate nodes by walking the SKB, producing the
    fixed candidate pool used for path-based reranking."""

    def __init__(self, dataset_name, text_retriever_name, scorer_name, skb, topk=100):
        """
        Args:
            dataset_name: dataset key ('prime', 'mag', or 'amazon').
            text_retriever_name: name of the seed text retriever.
            scorer_name: name of the final candidate scorer.
            skb: semi-structured knowledge base.
            topk: number of candidates returned by the text-retrieval fallback.
        """
        super(Stru4Path, self).__init__(skb)
        self.dataset_name = dataset_name
        self.text_retriever = get_text_retriever(dataset_name, text_retriever_name, skb)
        self.scorer = get_scorer(dataset_name, scorer_name=scorer_name, skb=skb)
        # self.scorer = self.text_retriever
        self.topk = topk
        self.node_type_list = skb.node_type_lst()
        self.edge_type_list = skb.rel_type_lst()
        if self.dataset_name == "prime":
            # prime: keep full (head, relation, tail) tuples; multiple answer types
            self.tp_list = skb.get_tuples()
            self.target_type_list = skb.candidate_types
        else:
            # mag/amazon: map (head, tail) -> relation; single answer type
            self.tp_dict = {(tp[0], tp[-1]): tp[1] for tp in skb.get_tuples()}
            self.target_type_list = ['paper' if dataset_name == 'mag' else 'product']

        self.skb = skb
        self.ini_k = 5  # topk for initial retrieval
        self.stru_count = 0  # number of queries answered via structural retrieval

    def rg2routes(self, rg):
        """
        input: rg: {"Metapath": "", "Restriction": {}}
        output: routes: [['paper', 'author', 'paper'], ['paper', 'paper']]

        Returns None when "Metapath" is neither a list nor a string.
        """
        # parse rg
        metapath = rg["Metapath"]
        if isinstance(rg["Metapath"], list):
            # already parsed into routes upstream
            routes = rg["Metapath"]
        elif isinstance(rg["Metapath"], str):
            routes = parse_metapath(metapath)
        else:
            return None

        return routes

    def check_valid(self, routes, rg):
        """Validate routes against the SKB schema.

        Returns:
            None  -- no usable route (caller emits an empty result);
            1     -- a single one-node route (caller falls back to text retrieval);
            list  -- valid routes; for non-prime datasets each route is rewritten
                     to interleave the relation names between node types.

        NOTE(review): `rg` is accepted but never read here.
        """
        # check the length of routes
        if not routes:
            # raise ValueError(f"Empty routes: {routes}")
            return None

        if len(routes) == 1 and len(routes[0]) == 1: # single node, directly do text retrieval
            return 1

        # Step 1: Filter routes by target type
        target_type_valid_routes = [
            route for route in routes if route[-1] in self.target_type_list
        ]
        if not target_type_valid_routes:
            return None

        # Step 2: Filter routes by node and edge type
        type_valid_routes = [
            route
            for route in target_type_valid_routes
            if all(
                node in self.node_type_list or node in self.edge_type_list
                for node in route
            )
        ]
        if not type_valid_routes:
            return None

        # Step 3: Check existence of relations
        relation_valid_routes = []
        for route in type_valid_routes:
            if self.dataset_name == "prime":
                # prime routes already interleave relations: node, rel, node, ...
                triplets = [
                    (route[i], route[i + 1], route[i + 2])
                    for i in range(0, len(route) - 2, 2)
                ]

                if all(tp in self.tp_list for tp in triplets):
                    relation_valid_routes.append(route)
            else:
                # mag/amazon routes are node-only; look up the relation for each
                # adjacent node pair and splice it into the route
                pairs = [(route[i], route[i + 1]) for i in range(len(route) - 1)]
                if all(tp in self.tp_dict.keys() for tp in pairs):
                    relations = [self.tp_dict[tp] for tp in pairs]

                    # make route with relations
                    new_route = []
                    for i in range(len(relations)):
                        new_route.append(pairs[i][0])
                        new_route.append(relations[i])
                    new_route.append(pairs[-1][-1])
                    # print(f"222, {new_route}")

                    relation_valid_routes.append(new_route)

        if not relation_valid_routes:
            return None

        return relation_valid_routes

    def get_candidates4route(self, query, q_id, route, restriction):
        """Walk one route through the SKB, starting from text-retrieved seeds.

        Side effect: appends this route's {node_id: path} dict to self.paths.

        NOTE(review): cur_node_type / next_node_type are computed but unused;
        only the edge type drives the neighbor expansion.
        """
        # initialization

        ini_node_type = route[0]

        # Append any restriction text for the seed node type to the query.
        try:
            extra_restr = "".join(restriction[ini_node_type])
        except:  # NOTE(review): bare except also hides non-KeyError bugs -- consider narrowing
            extra_restr = ""
        ini_dict = self.text_retriever.retrieve(query + " " + extra_restr, q_id=q_id, topk=self.ini_k, node_type=ini_node_type)
        current_node_ids = list(ini_dict.keys())

        # initialization for paths: each seed node starts its own path
        paths = {}
        for c_id in current_node_ids:
            paths[c_id] = [c_id]

        # loop
        hops = len(route)
        # for hop/layer: route is [node, edge, node, ...], so step by 2
        for hop in range(0, hops-2, 2):
            new_paths = {}

            cur_node_type = route[hop]
            next_node_type = route[hop+2]
            edge_type = route[hop+1]
            next_node_ids = []

            # for node
            for node_id in current_node_ids:
                neighbor_ids = self.skb.get_neighbor_nodes(idx=node_id, edge_type=edge_type)
                next_node_ids.extend(neighbor_ids)

                # ***** update paths *****
                # (a neighbor reachable from several nodes keeps the last path seen)
                for neighbor_id in neighbor_ids:
                    new_paths[neighbor_id] = paths[node_id] + [neighbor_id]

            paths = new_paths

            current_node_ids = list(set(next_node_ids))

        candidates = current_node_ids
        self.paths.append(paths)

        return candidates

    def merge_candidate_pools(self, non_empty_candidates_lists):
        """Intersect the per-route candidate lists; fall back to their union
        when the intersection is empty.

        Returns a set in the single-list case and a list otherwise --
        NOTE(review): callers only iterate / len() the result, so both work.
        """

        # if only one non-empty candidates list left, return it as a set
        if len(non_empty_candidates_lists) == 1:
            return set(non_empty_candidates_lists[0])
        # find the intersection candidates ids
        result = set(non_empty_candidates_lists[0])
        for lst in non_empty_candidates_lists[1:]:
            result.intersection_update(lst)

        # if the intersection is empty, return the union of all candidates
        if len(result) == 0:
            result = set()
            for lst in non_empty_candidates_lists:
                result.update(lst)

        return list(result)

    def get_mor_candidates(self, query, q_id, valid_routes, restriction):
        """Collect and merge candidates across routes; every score starts at -1.

        NOTE(review): only routes whose seed node type has a non-empty
        restriction are expanded -- confirm routes without restrictions are
        meant to be skipped entirely.
        """
        # Step 1: Get candidates for each route
        candidates_pool = []
        for route in valid_routes:
            if route[0] in restriction.keys() and len(restriction[route[0]]) > 0:
                candidates_pool.append(self.get_candidates4route(query, q_id, route, restriction)) # topk is the candidates retrieved from textual retriever

        non_empty_candidates_lists = [lst for lst in candidates_pool if lst]
        if not non_empty_candidates_lists: # no candidates, return empty dict
            print(f"123, {non_empty_candidates_lists}")

            # raise ValueError("No candidates for any route")
            return {}

        # Step 2: Combine candidates from different routes, try intersection first, then union
        # NOTE(review): this passes candidates_pool (which may still contain
        # empty lists), not non_empty_candidates_lists; an empty list zeroes
        # the intersection and forces the union fallback -- confirm intended.
        candidates = self.merge_candidate_pools(candidates_pool) # candidates is a list
        if not candidates:
            return {}

        # step 3: score the candidates, initialized to -1
        pred_dict = dict(zip(candidates, [-1]*len(candidates)))
        # print(f"111, {pred_dict}")

        return pred_dict

    def forward(self, query, q_id, ans_ids, rg):
        """Run structural retrieval for one query.

        Args:
            query: natural-language query string.
            q_id: query id (used to look up cached query embeddings).
            ans_ids: ground-truth answer ids, passed through to the output.
            rg: reasoning graph {"Metapath": ..., "Restriction": {...}}.

        Returns:
            dict with keys query / pred_dict / ans_ids / paths / query_pattern
            (plus 'rg' on the success path); pred_dict is empty when the
            metapath is invalid or retrieval finds nothing.
        """
        self.paths = []  # list of {node_id: path} dicts, one per expanded route
        # ***** Structural Retrieval *****

        # reasoning graph to routes
        s_time = time.time()
        routes = self.rg2routes(rg)
        # print(f"444, {time.time()-s_time}")

        # check valid
        s_time = time.time()
        valid_routes = self.check_valid(routes, rg) # add check for restriction
        # print(f"555, {time.time()-s_time}")

        if valid_routes is None:
            # return empty dict
            return {
                "query": query,
                "pred_dict": {},
                "ans_ids": ans_ids,
                'paths': {},
                'query_pattern': rg['Metapath']
            }
        elif valid_routes == 1: # TODO: empty string
            print(f"1234: {valid_routes}")
            # do text retrieval
            pred_dict = self.text_retriever.retrieve(query, q_id=q_id, topk=self.topk, node_type=f'{self.target_type_list[0]}')

        else:
            # do structural retrieval
            # truncate the valid_routes
            if self.dataset_name == "prime":
                pass
            else:
                # keep at most the last two hops (5 tokens: node, edge, node, edge, node)
                valid_routes = [route[-5:] for route in valid_routes]

            restriction = rg["Restriction"]
            pred_dict = self.get_mor_candidates(query, q_id, valid_routes, restriction)
            self.stru_count += 1

        # **** combine paths ****
        # NOTE(review): indentation reconstructed from a mangled diff -- this
        # tail is assumed to run for both branches above; confirm.
        if self.paths:
            self.paths = combine_dicts(self.paths, pred_dict=pred_dict) # return dict

        else:
            # text-retrieval branch: each candidate is its own length-1 path
            self.paths = {}
            for node_id in pred_dict.keys():
                self.paths[node_id] = [node_id]

        # if retrieved candidates is empty, return empty dict
        if not pred_dict:
            return {
                "query": query,
                "pred_dict": {},
                "ans_ids": ans_ids,
                'paths': {},
                'query_pattern': rg['Metapath']
            }

        # score the candidates
        pred_dict = self.scorer.score(query, q_id, list(pred_dict.keys()))

        # # **** length padding and truncate *****
        # self.paths = fix_length(self.paths)

        if len(self.paths) != len(pred_dict):
            print(f"paths: {self.paths}")
            print(f"pred_dict: {pred_dict}")
            raise ValueError(f"Length mismatch between paths and pred_dict: {len(self.paths)}, {len(pred_dict)}")

        output = {
            "query": query,
            "pred_dict": pred_dict,
            "ans_ids": ans_ids,
            'paths': self.paths,
            'query_pattern': rg['Metapath'],
            'rg': rg
        }

        return output
|
Reasoning/text_retrievers/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .stark_model import ModelForSTaRKQA
|
2 |
+
from .ada import Ada
|
3 |
+
from .bm25 import BM25
|
4 |
+
from .contriever import Contriever
|
Reasoning/text_retrievers/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (408 Bytes). View file
|
|
Reasoning/text_retrievers/__pycache__/ada.cpython-311.pyc
ADDED
Binary file (5.39 kB). View file
|
|
Reasoning/text_retrievers/__pycache__/bm25.cpython-311.pyc
ADDED
Binary file (7.02 kB). View file
|
|
Reasoning/text_retrievers/__pycache__/contriever.cpython-311.pyc
ADDED
Binary file (5.69 kB). View file
|
|
Reasoning/text_retrievers/__pycache__/stark_model.cpython-311.pyc
ADDED
Binary file (8.37 kB). View file
|
|
Reasoning/text_retrievers/ada.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
"""
|
3 |
+
input: query, query_id, candidates_ids
|
4 |
+
output: pred_dict: {node_id: similarity}
|
5 |
+
"""
|
6 |
+
|
7 |
+
import sys
|
8 |
+
from pathlib import Path
|
9 |
+
# Get the absolute path of the current script
|
10 |
+
current_file = Path(__file__).resolve()
|
11 |
+
project_root = current_file.parents[2]
|
12 |
+
# Add the project root to the system path
|
13 |
+
sys.path.append(str(project_root))
|
14 |
+
import torch
|
15 |
+
from Reasoning.text_retrievers.stark_model import ModelForSTaRKQA
|
16 |
+
|
17 |
+
class Ada(ModelForSTaRKQA):
    """Dense retriever over precomputed text-embedding-ada-002 embeddings
    loaded from disk; similarity is a dot product on `device`."""

    def __init__(self, skb, dataset_name, device):
        super(Ada, self).__init__(skb)
        # Precomputed embeddings live under Reasoning/data/emb/<dataset>/.
        self.emb_dir = f"{project_root}/Reasoning/data/emb/{dataset_name}/"
        self.query_emb_path = self.emb_dir + "text-embedding-ada-002/query/query_emb_dict.pt"
        self.query_emb_dict = torch.load(self.query_emb_path)
        # print(f"777, {self.query_emb_path}")

        self.candidate_emb_path = self.emb_dir + "text-embedding-ada-002/doc/candidate_emb_dict.pt"
        self.candidate_emb_dict = torch.load(self.candidate_emb_path)
        self.device = device

        # Every candidate id must have a precomputed embedding.
        assert len(self.candidate_emb_dict) == len(self.candidate_ids)

        # Stack candidate embeddings into one matrix, row order = candidate_ids.
        candidate_embs = [self.candidate_emb_dict[idx].view(1, -1) for idx in self.candidate_ids]
        self.candidate_embs = torch.cat(candidate_embs, dim=0).to(device)

    def score(self, query, q_id, candidate_ids):
        """
        Score the given candidates against the precomputed query embedding.

        pred_dict[node_id] = similarity (float, via .item())
        """
        # Dimension of query_emb: torch.Size([1, emb_dim])
        query_emb = self.query_emb_dict[q_id].view(1, -1)
        # Dimension of candidates_embs: torch.Size([num_candidates, emb_dim])
        # # candidates_embs = self.candidate_embs[candidates_ids]
        candi_embs = [self.candidate_emb_dict[c_id].view(1, -1) for c_id in candidate_ids]
        candidates_embs = torch.cat(candi_embs, dim=0).to(self.device)
        # Dimension of similarity: torch.Size([num_candidates])
        similarity = torch.matmul(query_emb.to(self.device), candidates_embs.T).squeeze(dim=0).cpu()
        pred_dict = {}
        for i in range(len(candidate_ids)):
            pred_dict[candidate_ids[i]] = similarity[i].item()

        return pred_dict

    def retrieve(self, query, q_id, topk, node_type=None):
        """Return the topk most similar candidates for cached query q_id.

        NOTE(review): `node_type` is accepted for interface parity but unused.
        """
        # Dimension of query_emb: torch.Size([1, emb_dim])
        query_emb = self.query_emb_dict[q_id].view(1, -1)

        similarity = torch.matmul(query_emb.to(self.device), self.candidate_embs.T).cpu()
        if isinstance(query, str):
            pred_dict = dict(zip(self.candidate_ids, similarity.view(-1)))

            # NOTE(review): pred_dict is unbound for non-str queries; the
            # selection below is reconstructed as part of the str branch --
            # confirm the original indentation.
            sorted_pred_ids = sorted(pred_dict, key=lambda x: pred_dict[x], reverse=True)
            selected_pred_ids = sorted_pred_ids[:topk]
            pred_dict = {id: pred_dict[id].item() for id in selected_pred_ids}
            print(f"sorted: {pred_dict}")

        return pred_dict
|
Reasoning/text_retrievers/bm25.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
input: query, node_type, topk
|
3 |
+
output: pred_dict: {node_id: score}
|
4 |
+
|
5 |
+
"""
|
6 |
+
import bm25s
|
7 |
+
from tqdm import tqdm
|
8 |
+
import sys
|
9 |
+
from pathlib import Path
|
10 |
+
# Get the absolute path of the current script
|
11 |
+
current_file = Path(__file__).resolve()
|
12 |
+
project_root = current_file.parents[2]
|
13 |
+
# Add the project root to the system path
|
14 |
+
sys.path.append(str(project_root))
|
15 |
+
|
16 |
+
from Reasoning.text_retrievers.stark_model import ModelForSTaRKQA
|
17 |
+
|
18 |
+
# Retrieval target per dataset; 'combine' pools every candidate type into one index.
target_type = dict(amazon='product', prime='combine', mag='paper')
|
19 |
+
|
20 |
+
class BM25(ModelForSTaRKQA):
    """Sparse lexical retriever: one bm25s index per node type, plus a pooled
    'combine' index when the dataset's target type is not a real node type."""

    def __init__(self, skb, dataset_name):
        super(BM25, self).__init__(skb)
        self.retrievers = {}       # type_name -> bm25s.BM25 index
        self.text_to_ids = {}      # type_name -> {hash(doc_text): node_id}
        type_names = skb.node_type_lst()
        self.nodeid_to_index = {}  # type_name -> {node_id: corpus position}

        self.target_type = target_type[dataset_name]

        # 'combine' (prime) is not a node type: index the whole candidate pool.
        if self.target_type not in type_names:
            ids = skb.get_candidate_ids()

            # NOTE(review): this branch skips the '&' / 'P.O.R' sanitization
            # applied to the per-type corpora below -- confirm intended.
            corpus = [skb.get_doc_info(id) for id in tqdm(ids, desc=f"Gathering docs for combine")]
            retriever = bm25s.BM25(corpus=corpus)
            retriever.index(bm25s.tokenize(corpus))
            # Build hash map from text to node_id
            text_to_id = {hash(text): id for text, id in zip(corpus, ids)}
            # Store the retriever and text_to_id by type_name
            self.retrievers[self.target_type] = retriever
            self.text_to_ids[self.target_type] = text_to_id

            self.nodeid_to_index[self.target_type] = {id: i for i, id in enumerate(ids)}

        # Initialize retrievers and text-to-index maps for each type_name
        for type_name in type_names:
            ids = skb.get_node_ids_by_type(type_name)

            # we manually replace '&' with '_and_' to avoid the error in BM25, because BM25 uses '&' as a special character and will not tokenize it
            corpus = [skb.get_doc_info(id).replace('&', '_and_').replace('P.O.R', 'P_dot_O_dot_R') for id in tqdm(ids, desc=f"Gathering docs for {type_name}")]
            # Create the BM25 model for the current type_name
            retriever = bm25s.BM25(corpus=corpus)
            retriever.index(bm25s.tokenize(corpus))

            # Build hash map from text to index
            text_to_id = {hash(text): id for text, id in zip(corpus, ids)}

            # Store the retriever and text_to_id by type_name
            self.retrievers[type_name] = retriever
            self.text_to_ids[type_name] = text_to_id

            # build map from node_id to index
            self.nodeid_to_index[type_name] = {id: i for i, id in enumerate(ids)}

    def score(self, query, q_id, candidate_ids):
        """Score each candidate against `query` using its node type's index.

        NOTE(review): the query is re-tokenized once per candidate inside the
        loop -- hoisting it would avoid redundant work; confirm before changing.
        """
        pred_dict = {}

        for c_id in candidate_ids:
            type_name = self.skb.get_node_type_by_id(c_id)
            score = self.retrievers[type_name].get_scores(list(bm25s.tokenize(query)[1].keys()))[self.nodeid_to_index[type_name][c_id]] # save the query tokens
            pred_dict[c_id] = score

        # print(f"999, {pred_dict}")

        return pred_dict

    def retrieve(self, query, q_id, topk, node_type=None):
        """
        Forward pass to compute similarity scores for the given query.

        Args:
            query (str): Query string.

        Returns:
            pred_dict (dict): A dictionary of candidate ids and their corresponding similarity scores.
        """
        # Apply the same sanitization used when the corpora were indexed.
        if '&' in query:
            query = query.replace('&', '_and_')
        if 'P.O.R' in query:
            query = query.replace('P.O.R', 'P_dot_O_dot_R')
        # Multiple requested types fall back to the pooled 'combine' index.
        if isinstance(node_type, list):
            if len(node_type) > 1:
                node_type = 'combine'
            else:
                node_type = node_type[0]
        results, scores = self.retrievers[node_type].retrieve(bm25s.tokenize(query), k=topk)
        # Map each returned document text back to its node id via the hash map.
        ids = [self.text_to_ids[node_type][hash(result.item())] for result in results[0]]
        scores = scores[0].tolist()
        pred_dict = dict(zip(ids, scores))
        # print(f"666, {pred_dict}")

        return pred_dict
|
105 |
+
|
106 |
+
if __name__ == '__main__':
    # Smoke-test entry point. The original used an f-string with no
    # placeholders; a plain literal produces identical output.
    print("Testing BM25")
|
108 |
+
|
Reasoning/text_retrievers/contriever.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
"""
|
3 |
+
input: query, query_id, candidates_ids
|
4 |
+
output: pred_dict: {node_id: similarity}
|
5 |
+
"""
|
6 |
+
import heapq
|
7 |
+
import sys
|
8 |
+
from pathlib import Path
|
9 |
+
# Get the absolute path of the current script
|
10 |
+
current_file = Path(__file__).resolve()
|
11 |
+
project_root = current_file.parents[2]
|
12 |
+
# Add the project root to the system path
|
13 |
+
sys.path.append(str(project_root))
|
14 |
+
from stark_qa.tools.api_lib.openai_emb import get_contriever, get_contriever_embeddings
|
15 |
+
import torch
|
16 |
+
from Reasoning.text_retrievers.stark_model import ModelForSTaRKQA
|
17 |
+
|
18 |
+
class Contriever(ModelForSTaRKQA):
    """Dense retriever over precomputed Contriever embeddings; `retrieve`
    encodes the query on the fly, while `score` reuses cached query embeddings."""

    def __init__(self, skb, dataset_name, device):
        super(Contriever, self).__init__(skb)
        # Precomputed embeddings live under Reasoning/data/emb/<dataset>/.
        self.emb_dir = f"{project_root}/Reasoning/data/emb/{dataset_name}/"

        self.query_emb_path = self.emb_dir + "contriever/query_no_rel_no_compact/query_emb_dict.pt"
        self.query_emb_dict = torch.load(self.query_emb_path)

        self.candidate_emb_path = self.emb_dir + "contriever/doc_no_rel_no_compact/candidate_emb_dict.pt"
        self.candidate_emb_dict = torch.load(self.candidate_emb_path)
        self.device = device

        # Every candidate id must have a precomputed embedding.
        assert len(self.candidate_emb_dict) == len(self.candidate_ids)

        # Stack candidate embeddings into one matrix, row order = candidate_ids.
        candidate_embs = [self.candidate_emb_dict[idx].view(1, -1) for idx in self.candidate_ids]
        self.candidate_embs = torch.cat(candidate_embs, dim=0).to(device)

        # load contriever for query embeddings
        self.encoder, self.tokenizer = get_contriever(dataset_name=dataset_name)
        self.encoder = self.encoder.to(device)

    def score(self, query, q_id, candidate_ids):
        """
        Score candidates against the precomputed query embedding.

        pred_dict[node_id] = similarity (float, via .item())
        """
        # Dimension of query_emb: torch.Size([1, emb_dim])
        query_emb = self.query_emb_dict[q_id].view(1, -1)

        # Dimension of candidates_embs: torch.Size([num_candidates, emb_dim])
        candi_embs = [self.candidate_emb_dict[c_id].view(1, -1) for c_id in candidate_ids]
        candidates_embs = torch.cat(candi_embs, dim=0).to(self.device)
        # Dimension of similarity: torch.Size([num_candidates])
        similarity = torch.matmul(query_emb.to(self.device), candidates_embs.T).squeeze(dim=0).cpu()
        pred_dict = {}
        for i in range(len(candidate_ids)):
            pred_dict[candidate_ids[i]] = similarity[i].item()

        return pred_dict

    def retrieve(self, query, q_id, topk, node_type=None):
        """Encode `query` with Contriever and return the topk candidates.

        NOTE(review): `node_type` is accepted for interface parity but unused.
        """
        # Dimension of query_emb: torch.Size([1, emb_dim])
        query_emb = get_contriever_embeddings(query, encoder=self.encoder, tokenizer=self.tokenizer, device=self.device)

        similarity = torch.matmul(query_emb.to(self.device), self.candidate_embs.T).cpu()

        if isinstance(query, str):
            pred_dict = dict(zip(self.candidate_ids, similarity.view(-1)))

            # NOTE(review): pred_dict is unbound for non-str queries; the
            # selection below is reconstructed as part of the str branch --
            # confirm the original indentation.
            selected_pred_ids = heapq.nlargest(topk, pred_dict, key=pred_dict.get)
            pred_dict = {id: pred_dict[id].item() for id in selected_pred_ids}

        return pred_dict
|
76 |
+
|
77 |
+
if __name__ == '__main__':
    # Smoke-test entry point.
    banner = "Testing Contriever"
    print(banner)
|
Reasoning/text_retrievers/stark_model.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import os.path as osp
|
3 |
+
from typing import Any, Union, List, Dict
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
import sys
|
8 |
+
from pathlib import Path
|
9 |
+
# Get the absolute path of the current script
|
10 |
+
current_file = Path(__file__).resolve()
|
11 |
+
project_root = current_file.parents[2]
|
12 |
+
# Add the project root to the system path
|
13 |
+
sys.path.append(str(project_root))
|
14 |
+
|
15 |
+
from stark_qa.tools.api import get_api_embeddings, get_sentence_transformer_embeddings
|
16 |
+
from stark_qa.tools.local_encoder import get_llm2vec_embeddings, get_gritlm_embeddings
|
17 |
+
from stark_qa.evaluator import Evaluator
|
18 |
+
|
19 |
+
|
20 |
+
class ModelForSTaRKQA(nn.Module):
    """Base class for STaRK QA retrieval models: holds the knowledge base,
    the candidate id list, an on-disk query-embedding cache, and wraps the
    stark_qa Evaluator for metric computation."""

    def __init__(self, skb, query_emb_dir='.'):
        """
        Initializes the model with the given knowledge base.

        Args:
            skb: Knowledge base containing candidate information.
            query_emb_dir: Directory that may hold a cached 'query_emb_dict.pt'.
        """
        super(ModelForSTaRKQA, self).__init__()
        self.skb = skb

        self.candidate_ids = skb.candidate_ids
        self.num_candidates = skb.num_candidates
        self.query_emb_dir = query_emb_dir

        # Reuse cached query embeddings when the file exists; otherwise start
        # with an empty cache that get_query_emb fills lazily.
        query_emb_path = osp.join(self.query_emb_dir, 'query_emb_dict.pt')
        if os.path.exists(query_emb_path):
            print(f'Load query embeddings from {query_emb_path}')
            self.query_emb_dict = torch.load(query_emb_path)
        else:
            self.query_emb_dict = {}
        self.evaluator = Evaluator(self.candidate_ids)

    def forward(self,
                query: Union[str, List[str]],
                candidates: List[int] = None,
                query_id: Union[int, List[int]] = None,
                **kwargs: Any) -> Dict[str, Any]:
        """
        Forward pass to compute predictions for the given query.

        Args:
            query (Union[str, list]): Query string or a list of query strings.
            candidates (Union[list, None]): A list of candidate ids (optional).
            query_id (Union[int, list, None]): Query index (optional).

        Returns:
            pred_dict (dict): A dictionary of predicted scores or answer ids.
        """
        raise NotImplementedError

    def get_query_emb(self,
                      query: Union[str, List[str]],
                      query_id: Union[int, List[int]],
                      emb_model: str = 'text-embedding-ada-002',
                      **encode_kwargs) -> torch.Tensor:
        """
        Retrieves or computes the embedding for the given query.

        Args:
            query (str): Query string.
            query_id (int): Query index.
            emb_model (str): Embedding model to use.

        Returns:
            query_emb (torch.Tensor): Query embedding, shape (len(query), emb_dim).
        """
        # Normalize scalar inputs to one-element lists.
        if isinstance(query_id, int):
            query_id = [query_id]
        if isinstance(query, str):
            query = [query]

        encode_kwargs['is_query'] = True
        # get_embeddings is a module-level helper defined later in this file.
        if query_id is None:
            query_emb = get_embeddings(query, emb_model, **encode_kwargs)
        elif set(query_id).issubset(set(list(self.query_emb_dict.keys()))):
            # Full cache hit: concatenate the cached rows.
            query_emb = torch.concat([self.query_emb_dict[qid] for qid in query_id], dim=0)
        else:
            # Cache miss: encode, then persist the updated cache to disk.
            query_emb = get_embeddings(query, emb_model, **encode_kwargs)
            for qid, emb in zip(query_id, query_emb):
                self.query_emb_dict[qid] = emb.view(1, -1)
            torch.save(self.query_emb_dict, osp.join(self.query_emb_dir, 'query_emb_dict.pt'))

        query_emb = query_emb.view(len(query), -1)
        return query_emb

    def evaluate(self,
                 pred_dict: Dict[int, float],
                 answer_ids: Union[torch.LongTensor, List[Any]],
                 metrics: List[str] = ['mrr', 'hit@3', 'recall@20'],
                 **kwargs: Any) -> Dict[str, float]:
        """
        Evaluates the predictions using the specified metrics.

        Args:
            pred_dict (Dict[int, float]): Predicted answer ids or scores.
            answer_ids (torch.LongTensor): Ground truth answer ids.
            metrics (List[str]): A list of metrics to be evaluated, including 'mrr', 'hit@k', 'recall@k',
                'precision@k', 'map@k', 'ndcg@k'.

        Returns:
            Dict[str, float]: A dictionary of evaluation metrics.
        """
        return self.evaluator(pred_dict, answer_ids, metrics)

    def evaluate_batch(self,
                       pred_ids: List[int],
                       pred: torch.Tensor,
                       answer_ids: Union[torch.LongTensor, List[Any]],
                       metrics: List[str] = ['mrr', 'hit@3', 'recall@20'],
                       **kwargs: Any) -> Dict[str, float]:
        """Batched variant of `evaluate`; delegates to the stark_qa Evaluator."""
        return self.evaluator.evaluate_batch(pred_ids, pred, answer_ids, metrics)
123 |
+
|
124 |
+
|
125 |
+
def get_embeddings(text, model_name, **encode_kwargs):
    """
    Get embeddings for the given text using the specified model.

    Args:
        model_name (str): Model name.
        text (Union[str, List[str]]): The input text to be embedded.

    Returns:
        torch.Tensor: Embedding of the input text, one row per input string.
    """
    if isinstance(text, str):
        text = [text]

    # Route to the encoder family implied by the model name; anything
    # unrecognized is assumed to be an API-served embedding model.
    if 'GritLM' in model_name:
        encode = get_gritlm_embeddings
    elif 'LLM2Vec' in model_name:
        encode = get_llm2vec_embeddings
    elif 'all-mpnet-base-v2' in model_name or 'dunzhang/stella_en_1.5B_v5' in model_name:
        encode = get_sentence_transformer_embeddings
    else:
        encode = get_api_embeddings

    emb = encode(text, model_name, **encode_kwargs)
    return emb.view(len(text), -1)
148 |
+
|
149 |
+
|
150 |
+
if __name__ == '__main__':
    # Smoke-test entry point.
    banner = "Testing ModelForSTaRKQA..."
    print(banner)
|
Reasoning/utils.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from Reasoning.text_retrievers.bm25 import BM25
|
2 |
+
from Reasoning.text_retrievers.ada import Ada
|
3 |
+
from Reasoning.text_retrievers.contriever import Contriever
|
4 |
+
|
5 |
+
|
6 |
+
def combine_dicts(dicts_list, pred_dict):
    """Merge per-route path dicts into one.

    On a key collision the strictly longer path wins (ties keep the first
    one seen); the merged dict is then restricted to the keys present in
    `pred_dict`. A single-element list is returned as-is, unfiltered.
    """
    if len(dicts_list) == 1:
        return dicts_list[0]

    merged = {}
    for current in dicts_list:
        for node_id, path in current.items():
            if node_id not in merged or len(path) > len(merged[node_id]):
                merged[node_id] = path

    # If the reasoning paths intersect, only keep the keys in pred_dict.
    return {node_id: merged[node_id] for node_id in pred_dict.keys()}
|
25 |
+
|
26 |
+
def fix_length(paths_dict):
    """Normalize every route to exactly 3 entries.

    Longer routes are truncated from the front (most recent hops kept);
    shorter routes are left-padded with -1.
    """
    target = 3
    fixed = {}
    for node_id, route in paths_dict.items():
        trimmed = route[-target:] if len(route) > target else route
        pad = [-1] * (target - len(trimmed))
        fixed[node_id] = pad + trimmed
    return fixed
|
39 |
+
|
40 |
+
|
41 |
+
|
42 |
+
def parse_metapath(metapath):
    """
    Split a metapath string into maximal same-direction segments.

    input:  metapath: "paper -> author -> paper <- paper"
    output: routes:   [['paper', 'author', 'paper'], ['paper', 'paper']]

    Returns None when a separator other than '->' / '<-' is encountered,
    and [[metapath]] for a bare single node type.
    """

    def take_segment(tokens, arrow):
        """Consume the leading run of `arrow`-linked nodes from `tokens`.

        Returns (segment, leftover). '<-' segments are reversed so every
        returned segment reads source-to-target; `leftover` is None once
        the token list is exhausted, otherwise it still begins with the
        node shared between the two adjacent segments.
        """
        pos = 0
        segment = []
        while pos < len(tokens) - 1 and tokens[pos + 1] == arrow:
            segment.append(tokens[pos])
            pos += 2
        segment.append(tokens[pos])

        if arrow == "<-":
            segment.reverse()

        leftover = tokens[pos:] if len(tokens) != pos + 1 else None
        return segment, leftover

    tokens = metapath.split(' ')

    if len(tokens) == 1:  # a single node type, no arrows
        return [tokens]

    routes = []
    while tokens is not None:
        arrow = tokens[1]
        if arrow not in ("<-", "->"):
            # Malformed metapath; caller treats None as "unparseable".
            return None
        segment, tokens = take_segment(tokens, arrow)
        routes.append(segment)

    return routes
|
91 |
+
|
92 |
+
|
93 |
+
def get_text_retriever(dataset_name, retriever_name, skb, **kwargs):
    """Instantiate the text retriever named by `retriever_name`.

    Supported: 'bm25', 'ada', 'contriever'. The dense retrievers accept a
    'device' kwarg (defaults to 'cuda'). Raises ValueError otherwise.
    """
    if retriever_name == "bm25":
        return BM25(skb, dataset_name)
    device = kwargs.get("device", 'cuda')
    if retriever_name == "ada":
        return Ada(skb, dataset_name, device)
    if retriever_name == "contriever":
        return Contriever(skb, dataset_name, device)
    raise ValueError(f"Invalid retriever name: {retriever_name}")
|
102 |
+
|
103 |
+
|
104 |
+
def get_scorer(dataset_name, scorer_name, skb, **kwargs):
    """Instantiate the candidate scorer named by `scorer_name`.

    Mirrors get_text_retriever: 'bm25', 'ada', 'contriever', with an
    optional 'device' kwarg (defaults to 'cuda'). Raises ValueError otherwise.
    """
    if scorer_name == "bm25":
        return BM25(skb, dataset_name)
    device = kwargs.get("device", 'cuda')
    if scorer_name == "ada":
        return Ada(skb, dataset_name, device)
    if scorer_name == "contriever":
        return Contriever(skb, dataset_name, device)
    raise ValueError(f"Invalid scorer name: {scorer_name}")
|
113 |
+
|
114 |
+
|
115 |
+
if __name__ == "__main__":
    # Smoke-test entry point; the module's functions are exercised by importers.
    # (f-string prefix removed: the literal had no placeholders — F541.)
    print("Test utils")
|
Reranking/__pycache__/rerank.cpython-311.pyc
ADDED
Binary file (18 kB). View file
|
|
Reranking/__pycache__/utils.cpython-311.pyc
ADDED
Binary file (5.23 kB). View file
|
|
Reranking/data/checkpoints/amazon/best.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a4ef56bf1095b28283a5bc451556e238b31d36285df39b1a84b9c5c65f1900a2
|
3 |
+
size 804607
|
Reranking/data/checkpoints/mag/best.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b25c04fcb5fa1210fe8ca10bb7c1b3f3aad592ec0e501a1ccc330b92772550f1
|
3 |
+
size 804620
|
Reranking/data/checkpoints/prime/best.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d1e4e24531067e798b975e33689cc87645a2daa9d736df1e8ac0474f72154ce0
|
3 |
+
size 804633
|
Reranking/rerank.py
ADDED
@@ -0,0 +1,375 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import os
|
3 |
+
sys.path.append(os.path.dirname(os.path.abspath(os.getcwd())))
|
4 |
+
from stark_qa import load_skb
|
5 |
+
|
6 |
+
|
7 |
+
from torch.utils.data import Dataset, DataLoader
|
8 |
+
import torch
|
9 |
+
from tqdm import tqdm
|
10 |
+
import numpy as np
|
11 |
+
import torch.nn as nn
|
12 |
+
|
13 |
+
from Reranking.utils import move_to_cuda, seed_everything
|
14 |
+
from Reranking.rerankers.path import PathReranker
|
15 |
+
import torch.nn.functional as F
|
16 |
+
import argparse
|
17 |
+
import pickle as pkl
|
18 |
+
|
19 |
+
|
20 |
+
|
21 |
+
|
22 |
+
class TestDataset(Dataset):
    """
    Test-time dataset over reranking candidates.

    Each element of `saved_data['data']` has the format:
        {
            "query": query,
            "pred_dict": {node_id: score},
            'score_vector_dict': {node_id: [bm25, bm_25, bm25, ada]},
            "text_emb_dict": {node_id: text_emb},
            "ans_ids": [],
        }
    """

    def __init__(self, saved_data, args):
        # args: argparse.Namespace; this class reads .dataset_name and .vector_dim.
        print(f"Start processing test dataset...")
        self.text2emb_dict = saved_data['text2emb_dict']
        self.data = saved_data['data']

        # Stack all text embeddings into one matrix for fast row lookup.
        self.text_emb_matrix = list(self.text2emb_dict.values())
        self.text_emb_matrix = torch.concat(self.text_emb_matrix, dim=0)

        # make the mapping between the key of text2emb_dict and the index of text_emb_matrix
        self.text2idx_dict = {key: idx for idx, key in enumerate(self.text2emb_dict.keys())}

        self.args = args

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # NOTE(review): this mutates self.data[idx] in place, so iterating the
        # dataset more than once per process re-applies the mapping — confirm
        # the loader is only consumed once.
        if self.args.dataset_name == 'amazon':
            # change from the str key to an index into text_emb_matrix
            self.data[idx]['text_emb_dict'] = {key: self.text2idx_dict[value] for key, value in self.data[idx]['text_emb_dict'].items()}
        else:
            # sort the pred_dict by the score
            pred_dict = self.data[idx]['pred_dict']
            sorted_ids = sorted(pred_dict.keys(), key=lambda x: pred_dict[x], reverse=True)
            # get the top 50 candidates
            sorted_ids = sorted_ids[:50]
            # restrict the score vectors to the retained candidates
            self.data[idx]['score_vector_dict'] = {key: self.data[idx]['score_vector_dict'][key] for key in sorted_ids}
            # restrict the symbolic encodings likewise
            self.data[idx]['symb_enc_dict'] = {key: self.data[idx]['symb_enc_dict'][key] for key in sorted_ids}
            # change from the str key to an index into text_emb_matrix
            self.data[idx]['text_emb_dict'] = {key: self.text2idx_dict[value] for key, value in self.data[idx]['text_emb_dict'].items()}
            self.data[idx]['text_emb_dict'] = {key: self.data[idx]['text_emb_dict'][key] for key in sorted_ids}

        return self.data[idx]

    def collate_batch(self, batch):
        """Collate a list of items into batched tensors for the reranker."""
        # q: raw query strings
        batch_q = [batch[i]['query'] for i in range(len(batch))]
        q_text = batch_q

        # c: candidate ids, [batch, num_candidates]
        batch_c = [list(batch[i]['score_vector_dict'].keys()) for i in range(len(batch))]
        batch_c = torch.tensor(batch_c)
        c_score_vector = [list(batch[i]['score_vector_dict'].values()) for i in range(len(batch))]  # [batch, num_candidates, 4]
        c_score_vector = torch.tensor(c_score_vector)
        # keep only the first vector_dim similarity scores per candidate
        c_score_vector = c_score_vector[:, :, :self.args.vector_dim]

        # c_symb_enc: symbolic path encodings, [bs, num_candidates, 3]
        c_symb_enc = [list(batch[i]['symb_enc_dict'].values()) for i in range(len(batch))]
        c_symb_enc = torch.tensor(c_symb_enc)

        # c_text_emb: row lookup into the stacked embedding matrix
        c_text_emb = [self.text_emb_matrix[list(batch[i]['text_emb_dict'].values())].unsqueeze(0) for i in range(len(batch))]
        c_text_emb = torch.concat(c_text_emb, dim=0)  # [bs, num_candidates, 768]

        # ans_ids: ground-truth answer ids per query
        ans_ids = [batch[i]['ans_ids'] for i in range(len(batch))]  # list of ans_ids

        # pred_ids: candidate ids row-aligned with the score tensors
        pred_ids = batch_c.tolist()

        # Create a dictionary for the batch
        feed_dict = {
            'query': q_text,
            'c_score_vector': c_score_vector,
            'c_text_emb': c_text_emb,
            'c_symb_enc': c_symb_enc,
            'ans_ids': ans_ids,
            'pred_ids': pred_ids

        }

        return feed_dict
|
120 |
+
|
121 |
+
|
122 |
+
# ***** batch_evaluator *****
def batch_evaluator(skb, scores_cand, ans_ids, batch):
    """Compute hit@1, hit@5, recall@20 and MRR for one evaluation batch.

    Args:
        skb: knowledge base exposing the full `candidate_ids` list.
        scores_cand: reranker scores, shape [bs, num_pred, 1].
        ans_ids: list (len bs) of lists of ground-truth answer ids.
        batch: feed dict; `batch['pred_ids']` holds the candidate ids
            ([bs, num_pred]) row-aligned with `scores_cand`.

    Returns:
        dict mapping metric name -> list of per-query values (len bs).
    """
    results = {}

    # **** batch wise evaluation ****
    # Map every candidate id in the KB to a dense column index.
    candidates_ids = skb.candidate_ids
    id_to_idx = {candidate_id: idx for idx, candidate_id in enumerate(candidates_ids)}

    # Dense score matrix over ALL KB candidates; non-predicted entries stay 0.
    pred_matrix = torch.zeros((scores_cand.shape[0], len(candidates_ids)))

    # Flatten the per-query predicted ids so they can be mapped in one pass.
    flat_pred_ids = torch.tensor(batch['pred_ids']).flatten().tolist()

    # Translate predicted ids to dense column indices.
    pred_idx = [id_to_idx[pred_id] for pred_id in flat_pred_ids]

    # Back to [bs, num_pred] so rows align with scores_cand.
    pred_idx = torch.tensor(pred_idx).reshape(scores_cand.shape[0], -1)

    # Keep pred_matrix on the same device as the scores.
    pred_matrix = pred_matrix.to(scores_cand.device)

    # Scatter scores into the dense matrix via advanced indexing.
    pred_matrix[torch.arange(scores_cand.shape[0]).unsqueeze(1), pred_idx] = scores_cand.squeeze(-1)  # [bs, num_candidates]

    # Build the binary answer matrix the same way: flatten answers to indices.
    flat_ans_idx = [id_to_idx[a_id] for sublist in ans_ids for a_id in sublist]

    # Row index for each flattened answer (query i repeated len(ans_ids[i]) times).
    row_indices = torch.repeat_interleave(torch.arange(len(ans_ids)), torch.tensor([len(sublist) for sublist in ans_ids]))

    # Create the answer matrix
    ans_matrix = torch.zeros((scores_cand.shape[0], len(candidates_ids)), device=scores_cand.device)
    ans_matrix[row_indices, torch.tensor(flat_ans_idx, device=scores_cand.device)] = 1

    # hit@1: is the single top-scored candidate an answer?
    max_score, max_idx = torch.max(pred_matrix, dim=1)
    batch_hit1 = ans_matrix[torch.arange(scores_cand.shape[0]), max_idx]
    hit1_list = batch_hit1.tolist()

    # hit@5: does any of the top-5 candidates hit an answer?
    _, top5_idx = torch.topk(pred_matrix, 5, dim=1)
    batch_hit5 = ans_matrix[torch.arange(scores_cand.shape[0]).unsqueeze(1), top5_idx]
    # max over the 5 columns -> 1 if any hit
    batch_hit5 = torch.max(batch_hit5, dim=1)[0]
    hit5_list = batch_hit5.tolist()

    # recall@20: fraction of this query's answers recovered in the top-20.
    _, top20_idx = torch.topk(pred_matrix, 20, dim=1)
    batch_recall20 = ans_matrix[torch.arange(scores_cand.shape[0]).unsqueeze(1), top20_idx]
    # count recovered answers per row
    batch_recall20 = torch.sum(batch_recall20, dim=1)
    # divide by the total number of answers per row
    batch_recall20 = batch_recall20 / torch.sum(ans_matrix, dim=1)
    recall20_list = batch_recall20.tolist()

    # MRR: reciprocal rank of the first answer in the full score ordering.
    _, rank_idx = torch.sort(pred_matrix, dim=1, descending=True)
    # reorder the answer matrix by descending score
    batch_mrr = ans_matrix[torch.arange(scores_cand.shape[0]).unsqueeze(1), rank_idx]
    # argmax finds the first (best-ranked) answer position
    batch_mrr = torch.argmax(batch_mrr, dim=1)
    # convert 0-based position to 1-based rank
    batch_mrr += 1
    batch_mrr = 1 / batch_mrr.float()
    mrr_list = batch_mrr.tolist()

    results['hit@1'] = hit1_list
    results['hit@5'] = hit5_list
    results['recall@20'] = recall20_list
    results['mrr'] = mrr_list

    return results
|
221 |
+
|
222 |
+
|
223 |
+
|
224 |
+
# ***** evaluate *****
@torch.no_grad()
def evaluate(router, test_loader, skb):
    """Run the reranker over `test_loader` and return mean metrics.

    Args:
        router: reranking model (possibly wrapped in nn.DataParallel).
        test_loader: DataLoader yielding collate_batch feed dicts.
        skb: knowledge base, passed through to batch_evaluator.

    Returns:
        dict of averaged 'hit@1', 'hit@5', 'recall@20', 'mrr'.
    """
    router.eval()

    # per-query metric values accumulated across batches
    all_results = {
        "hit@1": [],
        "hit@5": [],
        "recall@20": [],
        "mrr": []
    }
    avg_results = {
        "hit@1": 0,
        "hit@5": 0,
        "recall@20": 0,
        "mrr": 0
    }

    # raw predictions/scores/answers (collected but not returned)
    pred_list = []
    scores_cand_list = []
    ans_ids_list = []
    print(f"Start evaluating...")
    # use tqdm to show the progress
    for idx, batch in enumerate(tqdm(test_loader, desc='Evaluating', position=0)):
        batch = move_to_cuda(batch)

        # Unwrap DataParallel so eval_batch is called on the underlying module.
        if isinstance(router, nn.DataParallel):
            scores_cand = router.module.eval_batch(batch)  # q_emb: [bs, 100], c_emb: [bs*100, 100]
        else:
            scores_cand = router.eval_batch(batch)

        # ground-truth answer ids for this batch
        ans_ids = batch['ans_ids']

        results = batch_evaluator(skb, scores_cand, ans_ids, batch)

        for key in results.keys():
            all_results[key].extend(results[key])

        # save the scores and ans_ids, and pred_ids
        pred_list.extend(batch['pred_ids'])
        scores_cand_list.extend(scores_cand.cpu().tolist())
        ans_ids_list.extend(ans_ids)

    # mean over all queries for each metric
    for key in avg_results.keys():
        avg_results[key] = np.mean(all_results[key])

    print(f"Results: {avg_results}")

    return avg_results
|
286 |
+
|
287 |
+
|
288 |
+
def parse_args():
    """Build the CLI argument namespace for the reranking evaluation run."""
    parser = argparse.ArgumentParser(description="Run PathRouter with dynamic combinations of embeddings.")

    # dataset to evaluate on
    parser.add_argument("--dataset_name", type=str, default="mag", help="Name of the dataset.")
    # target device
    parser.add_argument("--device", type=str, default="cuda", help="Device to run the model (e.g., 'cuda' or 'cpu').")
    # number of concatenated embedding views (overridden in __main__)
    parser.add_argument("--concat_num", type=int, default=0, help="Number of concatenation of embeddings.")
    # directory holding per-dataset checkpoints
    parser.add_argument("--checkpoint_path", type=str, default="./data/checkpoints", help="Path saves the checkpoints.")
    # similarity-vector dimensionality
    parser.add_argument("--vector_dim", type=int, default=4, help="Dimension of the similarity vector.")

    return parser.parse_args()
|
312 |
+
|
313 |
+
|
314 |
+
def get_concat_num(combo):
    """
    Determine concat_num from the embedding combination flags:
    - score_vec adds +1
    - text_emb adds +1
    - symb_enc adds +3
    """
    weights = {"score_vec": 1, "text_emb": 1, "symb_enc": 3}
    return sum(weight for name, weight in weights.items() if combo.get(name, False))
|
331 |
+
|
332 |
+
|
333 |
+
def run(test_data, skb, dataset_name):
    """Build the test loader, load the best checkpoint, and evaluate it.

    NOTE(review): relies on the module-level `args` namespace (created in
    __main__) and mutates `args.checkpoint_path` in place — calling this
    twice would append the dataset suffix twice; confirm intended usage.
    """
    # evaluation batch size
    test_size = 64
    test_dataset = TestDataset(test_data, args=args)
    test_loader = DataLoader(test_dataset, batch_size=test_size, num_workers=32, collate_fn=test_dataset.collate_batch)

    # load the model
    print(f"Load the model...")
    args.checkpoint_path = args.checkpoint_path + f"/{dataset_name}/best.pth"
    router = PathReranker(socre_vector_input_dim=4, text_emb_input_dim=768, symb_enc_dim=3, args=args)
    checkpoint = torch.load(args.checkpoint_path)
    router.load_state_dict(checkpoint)
    router = router.to(args.device)

    # evaluate the restored checkpoint on the test split
    test_results = evaluate(router, test_loader, skb)
    print(f"Test evaluation")
    print(test_results)

    return test_results
|
355 |
+
|
356 |
+
if __name__ == "__main__":
    # Feature views fed to the reranker; all three are enabled here.
    combo = {
        "text_emb": True,
        "score_vec": True,
        "symb_enc": True
    }
    concat_num = get_concat_num(combo)

    # Merge CLI args with the combo flags into a single namespace.
    base_args = parse_args()
    args = argparse.Namespace(**vars(base_args), **combo)
    args.concat_num = concat_num
    dataset_name = args.dataset_name

    # Load the pickled test split produced by the retrieval stage.
    test_data_path = f"../{dataset_name}_test.pkl"
    with open(test_data_path, 'rb') as f:
        test_data = pkl.load(f)
    skb = load_skb(dataset_name)
    results = run(test_data, skb, dataset_name)
|
375 |
+
|
Reranking/rerankers/__pycache__/node.cpython-311.pyc
ADDED
Binary file (5.38 kB). View file
|
|
Reranking/rerankers/__pycache__/path.cpython-311.pyc
ADDED
Binary file (5.43 kB). View file
|
|
Reranking/rerankers/node.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
|
3 |
+
|
4 |
+
|
5 |
+
# ***** Reranking Model *****
class Contriever(nn.Module):
    """Bi-encoder wrapper that mean-pools a transformer's token embeddings
    into fixed-size sentence embeddings."""

    def __init__(self, encoder, emb_dim=768):
        super(Contriever, self).__init__()
        self.encoder = encoder
        self.emb_dim = emb_dim

    def mean_pooling(self, token_embs, mask):
        """Average token embeddings over unmasked positions only."""
        masked = token_embs.masked_fill(~mask[..., None].bool(), 0.0)
        denom = mask.sum(dim=1)[..., None].clamp(min=1e-9)
        return masked.sum(dim=1) / denom

    def encode_seq(self, input_ids, attention_mask, token_type_ids=None):
        """Run the encoder and mean-pool its last hidden states."""
        enc = {'input_ids': input_ids, 'attention_mask': attention_mask}
        if token_type_ids is not None:
            enc['token_type_ids'] = token_type_ids
        outputs = self.encoder(**enc)
        return self.mean_pooling(outputs[0], attention_mask)

    def get_text_emb(self, input_ids, attention_mask, token_type_ids):
        """Thin alias over encode_seq kept for callers."""
        return self.encode_seq(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)

    def eval_batch(self, batch):
        """Embed query and candidate encodings for evaluation."""
        q_emb = self.get_text_emb(batch['q_enc_input_ids'], batch['q_enc_attention_mask'], batch['q_enc_token_type_ids'])
        c_emb = self.get_text_emb(batch['c_enc_input_ids'], batch['c_enc_attention_mask'], batch['c_enc_token_type_ids'])
        return q_emb, c_emb

    def forward(self, batch):
        """Embed query, positive, and negative encodings for training."""
        q_emb = self.get_text_emb(batch['q_enc_input_ids'], batch['q_enc_attention_mask'], batch['q_enc_token_type_ids'])
        p_emb = self.get_text_emb(batch['pos_enc_input_ids'], batch['pos_enc_attention_mask'], batch['pos_enc_token_type_ids'])
        n_emb = self.get_text_emb(batch['neg_enc_input_ids'], batch['neg_enc_attention_mask'], batch['neg_enc_token_type_ids'])
        return q_emb, p_emb, n_emb
|
57 |
+
|
58 |
+
|
59 |
+
# ***** Reranking Model *****
class NodeRouter(nn.Module):
    """Two-layer MLP scoring head: input_dim -> emb_dim -> output_dim,
    with a ReLU after each linear layer."""

    def __init__(self, input_dim=2, output_dim=1, emb_dim=128):
        super(NodeRouter, self).__init__()
        self.fc1 = nn.Linear(input_dim, emb_dim)
        self.fc2 = nn.Linear(emb_dim, output_dim)
        self.relu = nn.ReLU()

    def _score(self, features):
        """Shared scoring pipeline: fc1 -> ReLU -> fc2 -> ReLU."""
        hidden = self.relu(self.fc1(features))
        return self.relu(self.fc2(hidden))

    def eval_batch(self, batch):
        """Score candidate feature vectors at evaluation time."""
        return self._score(batch['c_scores'])

    def forward(self, batch):
        """Score positive and negative feature vectors for training."""
        return self._score(batch['p_scores']), self._score(batch['n_scores'])
|
87 |
+
|
Reranking/rerankers/path.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import torch
|
3 |
+
|
4 |
+
# ***** Reranking Model *****
# two linear layers
class PathReranker(nn.Module):
    """Scores candidate paths from up to three feature views — text
    embedding, retriever score vector, symbolic path encoding — selected
    by boolean flags on `args` (text_emb / score_vec / symb_enc).

    NOTE(review): the misspelled `socre_vector_input_dim` keyword is kept
    because callers pass it by name.
    """

    def __init__(self, socre_vector_input_dim=4, text_emb_input_dim=768, symb_enc_dim=3, output_dim=1, emb_dim=256, args=None):
        super(PathReranker, self).__init__()
        self.score_vec_enc = nn.Linear(socre_vector_input_dim, emb_dim)
        self.text_emb_enc = nn.Linear(text_emb_input_dim, emb_dim)
        self.symb_enc = nn.Embedding(symb_enc_dim, emb_dim)
        # fc1 consumes the concatenation of all enabled views.
        self.fc1 = nn.Linear(emb_dim * args.concat_num, output_dim)
        self.fc2 = nn.Linear(output_dim, 1)
        self.relu = nn.ReLU()
        self.args = args

    def eval_batch(self, batch):
        """Score every candidate: encode enabled views, concat, fc1 -> fc2.

        Returns scores of shape [bs, num_candidates, 1].
        """
        views = []

        if self.args.text_emb:
            views.append(self.relu(self.text_emb_enc(batch['c_text_emb'])))  # [bs, n, emb_dim]

        if self.args.score_vec:
            views.append(self.relu(self.score_vec_enc(batch['c_score_vector'])))  # [bs, n, emb_dim]

        if self.args.symb_enc:
            sym = self.relu(self.symb_enc(batch['c_symb_enc']))
            # flatten the per-hop embeddings into one vector per candidate
            views.append(sym.reshape(sym.shape[0], sym.shape[1], -1))

        fused = torch.cat(views, dim=-1) if len(views) > 1 else views[0]
        return self.fc2(self.fc1(fused))

    def forward(self, batch):
        """Return (positive, negative) path scores for training."""
        pos_views = []
        neg_views = []

        if self.args.text_emb:
            pos_views.append(self.relu(self.text_emb_enc(batch['p_text_emb'])))
            neg_views.append(self.relu(self.text_emb_enc(batch['n_text_emb'])))

        if self.args.score_vec:
            pos_views.append(self.relu(self.score_vec_enc(batch['p_score_vector'])))
            neg_views.append(self.relu(self.score_vec_enc(batch['n_score_vector'])))

        if self.args.symb_enc:
            sym_pos = self.relu(self.symb_enc(batch['p_symb_enc']))
            # positives are flattened per sample ([bs, path_len * emb_dim])
            pos_views.append(sym_pos.reshape(sym_pos.shape[0], -1))

            sym_neg = self.relu(self.symb_enc(batch['n_symb_enc']))
            # negatives keep the sample dimension ([bs, neg_sp, path_len * emb_dim])
            neg_views.append(sym_neg.reshape(sym_neg.shape[0], sym_neg.shape[1], -1))

        pos = torch.cat(pos_views, dim=-1) if len(pos_views) > 1 else pos_views[0]
        neg = torch.cat(neg_views, dim=-1) if len(neg_views) > 1 else neg_views[0]

        return self.fc2(self.fc1(pos)), self.fc2(self.fc1(neg))
|
107 |
+
|
Reranking/train_eval_path_amazon.py
ADDED
@@ -0,0 +1,694 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Standard library imports first so `os` is defined before the sys.path
# tweak below (the original used os.path before `import os`, a NameError).
import argparse
import json
import os
import pickle as pkl
import random
import sys
import time
from collections import defaultdict
from itertools import product

# Make the repository root importable when this file runs as a script.
sys.path.append(os.path.dirname(os.path.abspath(os.getcwd())))

# Third-party imports.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from torch.nn import CrossEntropyLoss
from torch.utils.data import Dataset, DataLoader
from torch_scatter import segment_csr, scatter_mean
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer, AutoConfig

# Project imports (require the sys.path tweak above).
from Reranking.rerankers.path import PathReranker
from Reranking.utils import move_to_cuda, seed_everything
from stark_qa import load_qa, load_skb
from utils import ModelForSTaRKQA

# Fix all RNG seeds for reproducible training runs.
seed_everything(42)
|
31 |
+
|
32 |
+
# ***** Dataset *****
|
33 |
+
class TrainDataset(Dataset):
    """
    Custom Dataset for the training data.

    Each query instance contains multiple positive and negative candidates.

    `saved_data` is expected to hold:
      - 'text2emb_dict': {text-key: embedding tensor}  # assumed [1, 768] rows — TODO confirm
      - 'data': list of per-query dicts with keys 'query', 'pred_dict',
        'ans_ids', 'score_vector_dict', 'text_emb_dict', 'symb_enc_dict'.

    After __init__, self.data is a flat list of
    (query, score_vector, text_emb, symb_enc) tuples — one per positive
    candidate — and self.sorted_query2neg maps each query to its negative
    candidates sorted by retriever score (highest first).
    """
    def __init__(self, saved_data, max_neg_candidates=100):
        """
        Preprocess the saved training data (roughly 10s for 1000 items).
        """
        print(f"start processing training dataset...")
        s_time = time.time()
        self.max_neg_candidates = max_neg_candidates
        self.sorted_query2neg = defaultdict(list)


        self.text2emb_dict = saved_data['text2emb_dict']
        self.data = saved_data['data']


        # Separate negatives from positives and prepare (query, pos) pairs.
        new_data = []

        for i in range(len(self.data)):
            neg_ids = []
            pos_ids = []
            item = self.data[i]


            candidates_dict = item['pred_dict']
            ans_ids = item['ans_ids']
            # A positive is an answer id that the retriever actually proposed;
            # answers missing from pred_dict are dropped.
            for ans_id in ans_ids:
                if ans_id in candidates_dict.keys():
                    pos_ids.append(ans_id)
            neg_ids = list(set(candidates_dict.keys()) - set(pos_ids))

            # load scores vector
            score_vector_dict = item['score_vector_dict']

            # load the text path key (str), resolved through text2emb_dict
            text_emb_dict = item['text_emb_dict']

            # load the symb_enc_dict
            symb_enc_dict = item['symb_enc_dict']


            self.data[i]['pos_ids'] = pos_ids
            self.data[i]['neg_ids'] = neg_ids

            query = item['query']
            # One training example per positive candidate.
            for pos_id in pos_ids:
                new_data.append((query, score_vector_dict[pos_id], self.text2emb_dict[text_emb_dict[pos_id]], symb_enc_dict[pos_id]))


            # Rank negatives by retriever score (highest first) so sampling in
            # collate_batch draws from an ordered pool.
            neg_dict = {neg_id: candidates_dict[neg_id] for neg_id in neg_ids}
            sorted_neg_ids = sorted(neg_dict.keys(), key=lambda x: neg_dict[x], reverse=True)  # return list

            # NOTE(review): assumes query strings are unique across items; a
            # duplicate query would overwrite its negative pool here — confirm.
            self.sorted_query2neg[query] = [(score_vector_dict[neg_id], self.text2emb_dict[text_emb_dict[neg_id]], symb_enc_dict[neg_id]) for neg_id in sorted_neg_ids]


        self.data = new_data
        print(f"Complete data preparation")
        print(f"Time: {time.time() - s_time}")


    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):

        return self.data[idx]

    def collate_batch(self, pairs):
        # Build one feed dict from a list of (query, score_vector, text_emb,
        # symb_enc) positive tuples; negatives are sampled per query here.
        s_time = time.time()

        # q
        batch_q = [pair[0] for pair in pairs]  # q is text
        q_text = batch_q


        # pos
        # get the score vector
        batch_p_score_vector = [pair[1] for pair in pairs]  # p is score vector
        batch_p_score_vector = torch.tensor(batch_p_score_vector)  # [bs, 4]
        # NOTE(review): relies on the module-level `args` for vector_dim.
        batch_p_score_vector = batch_p_score_vector[:, :args.vector_dim]
        # get the text emb
        batch_p_text_emb = [pair[2] for pair in pairs]  # p is text emb
        batch_p_text_emb = torch.concat(batch_p_text_emb, dim=0)  # [bs, 768]
        # get the symb_enc
        batch_p_symb_enc = [pair[3] for pair in pairs]  # p is symb_enc
        batch_p_symb_enc = torch.tensor(batch_p_symb_enc)  # [bs, 3]


        # Negative samples: sample WITH replacement so queries with fewer than
        # max_neg_candidates negatives still yield a full set (duplicates allowed).
        batch_n = [random.choices(self.sorted_query2neg[query], k=self.max_neg_candidates) for query in batch_q]


        # get the score vector
        batch_n_score_vector = [pair[0] for sublist in batch_n for pair in sublist]
        batch_n_score_vector = torch.tensor(batch_n_score_vector)  # [bs*100, 4]
        # reshape to [bs, 100, 4]
        batch_n_score_vector = batch_n_score_vector.reshape(len(batch_q), self.max_neg_candidates, -1)  # [bs, 100, 4]
        batch_n_score_vector = batch_n_score_vector[:, :, :args.vector_dim]

        # get the text emb
        batch_n_text_emb = [pair[1] for sublist in batch_n for pair in sublist]
        batch_n_text_emb = torch.concat(batch_n_text_emb, dim=0)  # [bs*100, 768]
        # reshape to [bs, 100, 768]
        batch_n_text_emb = batch_n_text_emb.reshape(len(batch_q), self.max_neg_candidates, -1)  # [bs, 100, 768]

        # get the symb_enc
        batch_n_symb_enc = [pair[2] for sublist in batch_n for pair in sublist]
        batch_n_symb_enc = torch.tensor(batch_n_symb_enc)  # [bs*100, 3]
        # reshape to [bs, 100, 3]
        batch_n_symb_enc = batch_n_symb_enc.reshape(len(batch_q), self.max_neg_candidates, -1)  # [bs, 100, 3]


        # Create a dictionary for the batch
        feed_dict = {
            'query': q_text,
            'p_score_vector': batch_p_score_vector,
            'p_text_emb': batch_p_text_emb,
            'p_symb_enc': batch_p_symb_enc,
            'n_score_vector': batch_n_score_vector,
            'n_text_emb': batch_n_text_emb,
            'n_symb_enc': batch_n_symb_enc,
        }


        return feed_dict
|
173 |
+
|
174 |
+
|
175 |
+
class TestDataset(Dataset):
    """
    Evaluation dataset for the path reranker.

    `saved_data` format: {
        'text2emb_dict': {text-key: embedding tensor},  # assumed [1, 768] rows — TODO confirm
        'data': [{
            "query": query,
            "pred_dict": {node_id: score},
            'score_vector_dict': {node_id: [bm25, bm_25, bm25, ada]},
            "text_emb_dict": {node_id: text-key},
            "symb_enc_dict": {node_id: symbolic encoding},
            "ans_ids": [],
        }, ...]
    }
    """

    def __init__(self, saved_data):
        """Build the shared embedding matrix and the text-key -> row mapping."""
        print(f"Start processing test dataset...")
        self.text2emb_dict = saved_data['text2emb_dict']
        self.data = saved_data['data']

        # Stack all text embeddings into one [num_texts, emb_dim] matrix.
        self.text_emb_matrix = list(self.text2emb_dict.values())
        self.text_emb_matrix = torch.concat(self.text_emb_matrix, dim=0)

        # make the mapping between the key of text2emb_dict and the index of text_emb_matrix
        self.text2idx_dict = {key: idx for idx, key in enumerate(self.text2emb_dict.keys())}

        print(f"Complete data preparation: {len(self.data)}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # ***** amazon *****
        # Replace each text key with its row index into text_emb_matrix.
        # Fix: skip values that are already integer indices, so repeated access
        # to the same item is idempotent. The original re-applied
        # text2idx_dict unconditionally, which raises KeyError on a second
        # access (e.g. multi-epoch evaluation with num_workers=0, where the
        # in-place mutation persists between epochs).
        self.data[idx]['text_emb_dict'] = {
            key: value if isinstance(value, int) else self.text2idx_dict[value]
            for key, value in self.data[idx]['text_emb_dict'].items()
        }

        return self.data[idx]

    def collate_batch(self, batch):
        """Assemble candidate features plus gold answer ids for one batch."""

        # q
        batch_q = [batch[i]['query'] for i in range(len(batch))]
        q_text = batch_q

        # c: candidate ids — torch.tensor() requires the same candidate count
        # per item (presumably 100) — TODO confirm against the data producer.
        batch_c = [list(batch[i]['score_vector_dict'].keys()) for i in range(len(batch))]  # [batch, 100]
        batch_c = torch.tensor(batch_c)
        c_score_vector = [list(batch[i]['score_vector_dict'].values()) for i in range(len(batch))]  # [batch, 100, 4]
        # NOTE(review): relies on the module-level `args` for vector_dim.
        c_score_vector = torch.tensor(c_score_vector)[:, :, :args.vector_dim]  # [batch, 100, 4]

        # c_text_emb: gather candidate rows from the shared embedding matrix
        c_text_emb = [self.text_emb_matrix[list(batch[i]['text_emb_dict'].values())].unsqueeze(0) for i in range(len(batch))]
        c_text_emb = torch.concat(c_text_emb, dim=0)  # [bs, 100, 768]

        # c_symb_enc
        c_symb_enc = [list(batch[i]['symb_enc_dict'].values()) for i in range(len(batch))]
        c_symb_enc = torch.tensor(c_symb_enc)  # [bs, 100, 3]

        # ans_ids
        ans_ids = [batch[i]['ans_ids'] for i in range(len(batch))]  # list of ans_ids

        # pred_ids
        pred_ids = batch_c.tolist()

        # Create a dictionary for the batch
        feed_dict = {
            'query': q_text,
            'c_score_vector': c_score_vector,
            'c_text_emb': c_text_emb,
            'c_symb_enc': c_symb_enc,
            'ans_ids': ans_ids,
            'pred_ids': pred_ids
        }

        return feed_dict
|
260 |
+
|
261 |
+
|
262 |
+
# ******* loss function ********
|
263 |
+
def loss_fn(scores_pos, scores_neg):
    """
    Listwise cross-entropy loss over one positive and its negatives.

    scores_pos: [bs, 1] positive score per query.
    scores_neg: [bs, max_neg, 1] negative scores per query.

    Each query's logits are (positive, negatives...); the correct class is
    always index 0, so the loss pushes the positive above every negative.
    """
    # Positive logit first, then the flattened negative logits.
    logits = torch.cat([scores_pos, scores_neg.squeeze(-1)], dim=1)  # [bs, 1 + max_neg]

    # Target class 0 == the positive candidate.
    labels = logits.new_zeros(logits.size(0), dtype=torch.long)

    # ignore_index=-1 is kept from the original; no label is ever -1 here.
    return F.cross_entropy(logits, labels, ignore_index=-1)
|
278 |
+
|
279 |
+
# ***** pairwise loss *****
|
280 |
+
def pairwise_loss(scores_pos, scores_neg, margin=0.5):
    """
    Margin-based pairwise hinge loss.

    scores_pos: [bs, 1] positive score per query.
    scores_neg: [bs, max_neg, 1] negative scores per query.
    margin: required gap between positive and each negative.

    A (pos, neg) pair contributes max(0, neg + margin - pos); the loss is the
    mean over all pairs.
    """
    # How much each negative (plus the margin) overtakes its positive;
    # equivalent to relu(-(pos - neg - margin)) in the original formulation.
    violation = scores_neg + margin - scores_pos.unsqueeze(1)  # [bs, max_neg, 1]
    return torch.clamp(violation.view(-1), min=0).mean()
|
290 |
+
|
291 |
+
|
292 |
+
def batch_evaluator(skb, scores_cand, ans_ids, batch):
    """
    Compute per-query retrieval metrics for one batch.

    skb: knowledge base providing the global `candidate_ids` pool.
    scores_cand: [bs, num_pred, 1] reranker scores for each predicted candidate.
    ans_ids: list (len bs) of gold answer id lists.
    batch: feed dict; only batch['pred_ids'] (the candidate ids matching
        scores_cand) is read here.

    Returns {'hit@1': [...], 'hit@5': [...], 'recall@20': [...], 'mrr': [...]},
    one float per query in the batch.
    """
    results = {}

    # **** batch wise evaluation ****
    # Map every global candidate id to a column index.
    candidates_ids = skb.candidate_ids
    id_to_idx = {candidate_id: idx for idx, candidate_id in enumerate(candidates_ids)}


    # Dense [bs, num_candidates] score matrix over the FULL candidate pool.
    # NOTE(review): unscored candidates stay at 0, so negative reranker scores
    # would rank below never-seen candidates — confirm scores are >= 0.
    pred_matrix = torch.zeros((scores_cand.shape[0],len(candidates_ids)))


    # get the index of each pred_ids
    # flatten the pred_ids
    flat_pred_ids = torch.tensor(batch['pred_ids']).flatten().tolist()


    # get the index of each pred_ids
    pred_idx = [id_to_idx[pred_id] for pred_id in flat_pred_ids]


    # reshape back to one row per query
    pred_idx = torch.tensor(pred_idx).reshape(scores_cand.shape[0], -1)  # [bs, 100]

    # move pred_matrix to the device
    pred_matrix = pred_matrix.to(scores_cand.device)

    # Scatter the scores into the dense matrix via advanced indexing.
    pred_matrix[torch.arange(scores_cand.shape[0]).unsqueeze(1), pred_idx] = scores_cand.squeeze(-1)  # [bs, num_candidates]


    # Build the 0/1 answer matrix the same way.
    # Flatten ans_ids to a single list and map them to indices
    flat_ans_idx = [id_to_idx[a_id] for sublist in ans_ids for a_id in sublist]

    # Row index for each answer: query i repeated len(ans_ids[i]) times.
    row_indices = torch.repeat_interleave(torch.arange(len(ans_ids)), torch.tensor([len(sublist) for sublist in ans_ids]))

    # Create the answer matrix
    ans_matrix = torch.zeros((scores_cand.shape[0], len(candidates_ids)), device=scores_cand.device)
    ans_matrix[row_indices, torch.tensor(flat_ans_idx, device=scores_cand.device)] = 1


    # hit@1: is the top-scored candidate an answer?
    max_score, max_idx = torch.max(pred_matrix, dim=1)
    batch_hit1 = ans_matrix[torch.arange(scores_cand.shape[0]), max_idx]
    hit1_list = batch_hit1.tolist()


    # hit@5: is any of the top-5 candidates an answer?
    _, top5_idx = torch.topk(pred_matrix, 5, dim=1)
    batch_hit5 = ans_matrix[torch.arange(scores_cand.shape[0]).unsqueeze(1), top5_idx]

    # max over the 5 gathered labels
    batch_hit5 = torch.max(batch_hit5, dim=1)[0]
    hit5_list = batch_hit5.tolist()


    # recall@20: fraction of gold answers inside the top-20.
    _, top20_idx = torch.topk(pred_matrix, 20, dim=1)
    batch_recall20 = ans_matrix[torch.arange(scores_cand.shape[0]).unsqueeze(1), top20_idx]
    # count answers found in the top-20
    batch_recall20 = torch.sum(batch_recall20, dim=1)
    # divide by the total number of answers per query
    batch_recall20 = batch_recall20 / torch.sum(ans_matrix, dim=1)
    recall20_list = batch_recall20.tolist()


    # mrr: reciprocal rank of the first answer in the full ranking.
    _, rank_idx = torch.sort(pred_matrix, dim=1, descending=True)
    # reorder answer labels by predicted rank
    batch_mrr = ans_matrix[torch.arange(scores_cand.shape[0]).unsqueeze(1), rank_idx]
    # argmax finds the first rank position holding an answer
    batch_mrr = torch.argmax(batch_mrr, dim=1)
    # ranks are 1-based
    batch_mrr += 1
    batch_mrr = 1 / batch_mrr.float()
    mrr_list = batch_mrr.tolist()


    results['hit@1'] = hit1_list
    results['hit@5'] = hit5_list
    results['recall@20'] = recall20_list
    results['mrr'] = mrr_list


    return results
|
402 |
+
|
403 |
+
|
404 |
+
|
405 |
+
# ***** evaluate *****
|
406 |
+
@torch.no_grad()
def evaluate(reranker, test_loader):
    """
    Run the reranker over `test_loader` and return the averaged metrics
    {'hit@1', 'hit@5', 'recall@20', 'mrr'}.

    Relies on the module-level `skb` for the global candidate pool.
    Gradients are disabled by the decorator; the model is put in eval mode.
    """
    reranker.eval()

    all_results = {
        "hit@1": [],
        "hit@5": [],
        "recall@20": [],
        "mrr": []
    }
    avg_results = {
        "hit@1": 0,
        "hit@5": 0,
        "recall@20": 0,
        "mrr": 0
    }


    # save the scores and ans_ids, and pred_ids
    pred_list = []
    scores_cand_list = []
    ans_ids_list = []
    # use tqdm to show the progress
    for idx, batch in enumerate(tqdm(test_loader, desc='Evaluating', position=0)):
        batch = move_to_cuda(batch)

        # DataParallel wraps the model; call eval_batch on the inner module.
        if isinstance(reranker, nn.DataParallel):
            scores_cand = reranker.module.eval_batch(batch)  # q_emb: [bs, 100], c_emb: [bs*100, 100]
        else:
            scores_cand = reranker.eval_batch(batch)


        # ans_ids
        ans_ids = batch['ans_ids']

        results = batch_evaluator(skb, scores_cand, ans_ids, batch)


        # accumulate per-query metric lists across batches
        for key in results.keys():
            all_results[key].extend(results[key])

        # NOTE(review): these three lists are filled but never returned or
        # written out — presumably kept for debugging; confirm before removal.
        pred_list.extend(batch['pred_ids'])
        scores_cand_list.extend(scores_cand.cpu().tolist())
        ans_ids_list.extend(ans_ids)


    # Average each metric over all evaluated queries.
    for key in avg_results.keys():
        avg_results[key] = np.mean(all_results[key])

    print(f"Results: {avg_results}")


    return avg_results
|
465 |
+
|
466 |
+
|
467 |
+
# ***** train *****
|
468 |
+
def main(train_data, val_data, test_data, skb, dataset_name, args):
    """
    Train the PathReranker and evaluate on val/test after every epoch.

    Saves the checkpoint with the best validation hit@1 (and the last-epoch
    checkpoint), logs metrics to wandb, and writes the best test results to a
    timestamped JSON under ./data/outputs/<dataset_name>/.

    Note: the `skb` parameter shadows the module-level `skb` that
    `evaluate` reads — both refer to the same object when called from
    __main__.
    """

    epochs = args.epochs
    device = args.device

    train_size = args.train_batch_size
    test_size = 64

    train_dataset = TrainDataset(train_data)
    train_loader = DataLoader(train_dataset, batch_size=train_size, num_workers=32, collate_fn=train_dataset.collate_batch, drop_last=True)

    test_dataset = TestDataset(test_data)
    test_loader = DataLoader(test_dataset, batch_size=test_size, num_workers=32, collate_fn=test_dataset.collate_batch)

    val_dataset = TestDataset(val_data)
    val_loader = DataLoader(val_dataset, batch_size=test_size, num_workers=32, collate_fn=val_dataset.collate_batch)


    # ***** Model *****
    # NOTE(review): 'socre_vector_input_dim' matches PathReranker's own
    # (misspelled) keyword — keep in sync with the class definition.
    reranker = PathReranker(socre_vector_input_dim=args.vector_dim, text_emb_input_dim=768, symb_enc_dim=3, args=args)
    save_dir = f"./data/checkpoints/{dataset_name}/path"
    os.makedirs(save_dir, exist_ok=True)

    reranker.to(device)
    # parallel processing across all visible GPUs
    reranker = nn.DataParallel(reranker)


    optimizer = torch.optim.Adam(reranker.parameters(), lr=args.lr)
    best_val_hit1 = float('-inf')


    # Baseline evaluation before any training step.
    val_results = evaluate(reranker, val_loader)
    print(f"Val evaluation")
    print(val_results)


    test_results = evaluate(reranker, test_loader)
    print(f"Test evaluation")
    print(test_results)

    # log both val and test results
    wandb.log({'val_mrr': val_results['mrr'], 'val_hit1': val_results['hit@1'], 'val_hit5': val_results['hit@5'], 'val_recall@20': val_results['recall@20'],
               'test_mrr': test_results['mrr'], 'test_hit1': test_results['hit@1'], 'test_hit5': test_results['hit@5'], 'test_recall@20': test_results['recall@20']})

    best_test_results = {}
    for epoch in tqdm(range(epochs), desc='Training Epochs', position=0):
        total_loss = 0.0
        reranker.train()
        count = 0
        total_instances = 0

        for batch in tqdm(train_loader):
            batch = move_to_cuda(batch)

            scores_pos, scores_neg = reranker(batch)

            # batch_loss = pairwise_loss(scores_pos, scores_neg)
            batch_loss = loss_fn(scores_pos, scores_neg)

            # clear optimizer
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            count += 1
            # NOTE(review): total_loss sums per-batch MEAN losses but is divided
            # by total_instances below, so 'Average Train Loss' is on a
            # per-instance scale rather than a mean of batch losses — confirm
            # this is intended before comparing runs.
            total_instances += scores_pos.shape[0]
            total_loss += batch_loss.item()


        train_loss = total_loss / total_instances

        print(f"Epoch {epoch+1}/{epochs}, Average Train Loss: {train_loss}")


        val_results = evaluate(reranker, val_loader)
        print(f"Val evaluation")
        print(val_results)


        test_results = evaluate(reranker, test_loader)
        print(f"Test evaluation")
        print(test_results)

        # log both val and test results
        wandb.log({'val_mrr': val_results['mrr'], 'val_hit1': val_results['hit@1'], 'val_hit5': val_results['hit@5'], 'val_recall@20': val_results['recall@20'],
                   'test_mrr': test_results['mrr'], 'test_hit1': test_results['hit@1'], 'test_hit5': test_results['hit@5'], 'test_recall@20': test_results['recall@20'],
                   'train_loss': train_loss})


        # save the best model when val hit1 is the highest
        hit1 = val_results['hit@1']
        if best_val_hit1 < hit1:
            best_val_hit1 = hit1

            save_path = f"{save_dir}/best_{best_val_hit1}.pth"

            # Unwrap DataParallel before saving so the checkpoint loads
            # without the 'module.' key prefix.
            if isinstance(reranker, nn.DataParallel):
                torch.save(reranker.module.state_dict(), save_path)
            else:
                torch.save(reranker.state_dict(), save_path)
            print(f"Checkpoint saved at epoch {epoch+1} with test hits@1 {hit1}")

            args.checkpoint_path = save_path
            best_test_results = test_results



    # save last epoch checkpoint
    save_path = f"{save_dir}/last_{hit1}.pth"
    if isinstance(reranker, nn.DataParallel):
        torch.save(reranker.module.state_dict(), save_path)
    else:
        torch.save(reranker.state_dict(), save_path)
    print(f"Final checkpoint saved at {save_path}")


    # ***** save the results *****
    results = []
    results.append(
        {
            "config": vars(args),
            "test_results": best_test_results
        }
    )
    # save the results to json
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    output_dir = f"./data/outputs/{dataset_name}"
    os.makedirs(output_dir, exist_ok=True)
    with open(f"{output_dir}/results_{timestamp}.json", "w") as f:
        json.dump(results, f, indent=4)

    print(best_test_results)
|
608 |
+
|
609 |
+
|
610 |
+
def get_concat_num(combo):
    """
    Determine the value of concat_num based on the combination of embeddings.

    - score_vec adds +1
    - text_emb adds +1
    - symb_enc adds +3
    """
    # Width contributed by each enabled representation.
    contributions = (("score_vec", 1), ("text_emb", 1), ("symb_enc", 3))
    return sum(width for flag, width in contributions if combo.get(flag, False))
|
627 |
+
|
628 |
+
|
629 |
+
|
630 |
+
def parse_args():
    """Build and parse the command-line arguments for the path reranker."""
    cli = argparse.ArgumentParser(description="Run Pathreranker with dynamic combinations of embeddings.")

    # Training / optimization configuration.
    cli.add_argument("--train_batch_size", type=int, default=64, help="Batch size for training or evaluation.")
    cli.add_argument("--lr", type=float, default=1e-4, help="Learning rate for optimizer.")
    cli.add_argument("--epochs", type=int, default=100, help="Number of epochs to train the model.")
    cli.add_argument("--device", type=str, default="cuda", help="Device to run the model (e.g., 'cuda' or 'cpu').")

    # Dataset selection and data file locations.
    cli.add_argument("--dataset_name", type=str, default="amazon", help="Name of the dataset to use.")
    cli.add_argument("--train_path", type=str, default=f"../amazon_train.pkl", help="Path to the training data.")
    cli.add_argument("--test_path", type=str, default=f"../amazon_test.pkl", help="Path to the test data.")
    cli.add_argument("--val_path", type=str, default=f"../amazon_val.pkl", help="Path to the validation data.")

    # Feature-concatenation width (recomputed later from the combo flags).
    cli.add_argument("--concat_num", type=int, default=0, help="Number of concatenation of embeddings.")

    # Checkpointing and score-vector truncation.
    cli.add_argument("--checkpoint_path", type=str, default="", help="Path to save the checkpoints.")
    cli.add_argument("--vector_dim", type=int, default=4, help="Dimension of the similarity vector.")

    return cli.parse_args()
|
658 |
+
|
659 |
+
|
660 |
+
if __name__ == "__main__":

    # Parse CLI args and pull out the data locations.
    base_args = parse_args()
    test_path = base_args.test_path
    train_path = base_args.train_path
    val_path = base_args.val_path
    dataset_name = base_args.dataset_name

    # Load the pickled train/val/test reranking data.
    with open(test_path, "rb") as f:
        test_data = pkl.load(f)

    with open(train_path, "rb") as f:
        train_data = pkl.load(f)

    with open(val_path, "rb") as f:
        val_data = pkl.load(f)

    # load skb (knowledge base providing the global candidate pool; also read
    # as a module-level global by evaluate())
    skb = load_skb(dataset_name)

    # Enable all three candidate representations.
    combo = {
        "text_emb": True,
        "score_vec": True,
        "symb_enc": True
    }
    concat_num = get_concat_num(combo)

    wandb.init(project=f'reranking-{dataset_name}', name=f"path")
    # Merge the CLI args with the embedding-combination flags into one namespace.
    args = argparse.Namespace(**vars(base_args), **combo)
    args.concat_num = concat_num


    main(train_data, val_data, test_data, skb, dataset_name, args)
|
694 |
+
|
Reranking/train_eval_path_mag.py
ADDED
@@ -0,0 +1,694 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
import sys

# Fix: `os` must be imported before it is used in the sys.path bootstrap below
# (the original called os.path.dirname here with `import os` much further down,
# which raises NameError at import time).
sys.path.append(os.path.dirname(os.path.abspath(os.getcwd())))

import argparse
import json
import pickle as pkl
import random
import time
from collections import defaultdict
from itertools import product

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from torch.nn import CrossEntropyLoss
from torch.utils.data import Dataset, DataLoader
from torch_scatter import segment_csr, scatter_mean
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer, AutoConfig

from Reranking.utils import move_to_cuda, seed_everything
from Reranking.rerankers.path import PathReranker
from stark_qa import load_qa, load_skb
from utils import ModelForSTaRKQA
|
28 |
+
|
29 |
+
seed_everything(42)
|
30 |
+
|
31 |
+
# ***** Dataset *****
|
32 |
+
class TrainDataset(Dataset):
    """
    Custom Dataset for the training data.

    Each query instance contains multiple positive and negative candidates.
    Identical to the amazon variant of this script: after __init__, self.data
    is a flat list of (query, score_vector, text_emb, symb_enc) tuples — one
    per positive candidate — and self.sorted_query2neg maps each query to its
    negatives sorted by retriever score (highest first).
    """
    def __init__(self, saved_data, max_neg_candidates=100):
        """
        Preprocess the saved training data (roughly 10s for 1000 items).
        """
        print(f"start processing training dataset...")
        s_time = time.time()
        self.max_neg_candidates = max_neg_candidates
        self.sorted_query2neg = defaultdict(list)


        self.text2emb_dict = saved_data['text2emb_dict']
        self.data = saved_data['data']


        # Separate negatives from positives and prepare (query, pos) pairs.
        new_data = []

        for i in range(len(self.data)):
            neg_ids = []
            pos_ids = []
            item = self.data[i]


            candidates_dict = item['pred_dict']
            ans_ids = item['ans_ids']
            # A positive is an answer id that the retriever actually proposed.
            for ans_id in ans_ids:
                if ans_id in candidates_dict.keys():
                    pos_ids.append(ans_id)
            neg_ids = list(set(candidates_dict.keys()) - set(pos_ids))

            # load scores vector
            score_vector_dict = item['score_vector_dict']

            # load the text path key (str), resolved through text2emb_dict
            text_emb_dict = item['text_emb_dict']

            # load the symb_enc_dict
            symb_enc_dict = item['symb_enc_dict']


            self.data[i]['pos_ids'] = pos_ids
            self.data[i]['neg_ids'] = neg_ids

            query = item['query']
            # One training example per positive candidate.
            for pos_id in pos_ids:
                new_data.append((query, score_vector_dict[pos_id], self.text2emb_dict[text_emb_dict[pos_id]], symb_enc_dict[pos_id]))


            # Rank negatives by retriever score (highest first).
            neg_dict = {neg_id: candidates_dict[neg_id] for neg_id in neg_ids}
            sorted_neg_ids = sorted(neg_dict.keys(), key=lambda x: neg_dict[x], reverse=True)  # return list

            # NOTE(review): assumes query strings are unique across items — confirm.
            self.sorted_query2neg[query] = [(score_vector_dict[neg_id], self.text2emb_dict[text_emb_dict[neg_id]], symb_enc_dict[neg_id]) for neg_id in sorted_neg_ids]


        self.data = new_data
        print(f"Complete data preparation")
        print(f"Time: {time.time() - s_time}")


    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):

        return self.data[idx]

    def collate_batch(self, pairs):
        # Build one feed dict from positive tuples; negatives sampled per query.
        s_time = time.time()

        # q
        batch_q = [pair[0] for pair in pairs]  # q is text
        q_text = batch_q


        # pos
        # get the score vector
        batch_p_score_vector = [pair[1] for pair in pairs]  # p is score vector
        batch_p_score_vector = torch.tensor(batch_p_score_vector)  # [bs, 4]
        # NOTE(review): relies on the module-level `args` for vector_dim.
        batch_p_score_vector = batch_p_score_vector[:, :args.vector_dim]
        # get the text emb
        batch_p_text_emb = [pair[2] for pair in pairs]  # p is text emb
        batch_p_text_emb = torch.concat(batch_p_text_emb, dim=0)  # [bs, 768]
        # get the symb_enc
        batch_p_symb_enc = [pair[3] for pair in pairs]  # p is symb_enc
        batch_p_symb_enc = torch.tensor(batch_p_symb_enc)  # [bs, 3]


        # Negative samples: sample WITH replacement so queries with fewer than
        # max_neg_candidates negatives still yield a full set (duplicates allowed).
        batch_n = [random.choices(self.sorted_query2neg[query], k=self.max_neg_candidates) for query in batch_q]


        # get the score vector
        batch_n_score_vector = [pair[0] for sublist in batch_n for pair in sublist]
        batch_n_score_vector = torch.tensor(batch_n_score_vector)  # [bs*100, 4]
        # reshape to [bs, 100, 4]
        batch_n_score_vector = batch_n_score_vector.reshape(len(batch_q), self.max_neg_candidates, -1)  # [bs, 100, 4]
        batch_n_score_vector = batch_n_score_vector[:, :, :args.vector_dim]

        # get the text emb
        batch_n_text_emb = [pair[1] for sublist in batch_n for pair in sublist]
        batch_n_text_emb = torch.concat(batch_n_text_emb, dim=0)  # [bs*100, 768]
        # reshape to [bs, 100, 768]
        batch_n_text_emb = batch_n_text_emb.reshape(len(batch_q), self.max_neg_candidates, -1)  # [bs, 100, 768]

        # get the symb_enc
        batch_n_symb_enc = [pair[2] for sublist in batch_n for pair in sublist]
        batch_n_symb_enc = torch.tensor(batch_n_symb_enc)  # [bs*100, 3]
        # reshape to [bs, 100, 3]
        batch_n_symb_enc = batch_n_symb_enc.reshape(len(batch_q), self.max_neg_candidates, -1)  # [bs, 100, 3]


        # Create a dictionary for the batch
        feed_dict = {
            'query': q_text,
            'p_score_vector': batch_p_score_vector,
            'p_text_emb': batch_p_text_emb,
            'p_symb_enc': batch_p_symb_enc,
            'n_score_vector': batch_n_score_vector,
            'n_text_emb': batch_n_text_emb,
            'n_symb_enc': batch_n_symb_enc,
        }


        return feed_dict
|
172 |
+
|
173 |
+
|
174 |
+
class TestDataset(Dataset):
    """
    Dataset used at evaluation time.

    data format: {
        "query": query,
        "pred_dict": {node_id: score},
        'score_vector_dict': {node_id: [bm25, bm_25, bm25, ada]},
        "text_emb_dict": {node_id: text_emb},
        "ans_ids": [],
    }
    """

    def __init__(self, saved_data):
        """
        Args:
            saved_data: dict with
                'text2emb_dict': {text_key: [1, 768] tensor} — assumed shape, confirm upstream,
                'data': list of per-query dicts (see class docstring).
        """
        print(f"Start processing test dataset...")
        self.text2emb_dict = saved_data['text2emb_dict']
        self.data = saved_data['data']

        # Stack all path-text embeddings into a single matrix so that a batch
        # can be gathered with one fancy-index in collate_batch.
        self.text_emb_matrix = list(self.text2emb_dict.values())
        self.text_emb_matrix = torch.concat(self.text_emb_matrix, dim=0)

        # Map each text key to its row index in text_emb_matrix.
        self.text2idx_dict = {key: idx for idx, key in enumerate(self.text2emb_dict.keys())}

        print(f"Complete data preparation: {len(self.data)}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the idx-th item restricted to its top-50 candidates.

        Fix: the original implementation mutated self.data[idx] in place;
        on a second access of the same index (e.g. evaluating every epoch
        with num_workers=0) the str->index re-mapping of text_emb_dict
        raised a KeyError because the values had already been converted.
        We now build a shallow copy and leave the stored item untouched,
        making __getitem__ idempotent while returning exactly the same
        content as the original did on first access.
        """
        item = self.data[idx]

        # Rank candidates by the retriever score and keep the top 50.
        pred_dict = item['pred_dict']
        sorted_ids = sorted(pred_dict.keys(), key=lambda x: pred_dict[x], reverse=True)[:50]

        out = dict(item)  # shallow copy; the per-candidate dicts below are rebuilt
        out['score_vector_dict'] = {key: item['score_vector_dict'][key] for key in sorted_ids}
        out['symb_enc_dict'] = {key: item['symb_enc_dict'][key] for key in sorted_ids}
        # Convert text keys into row indices of text_emb_matrix.
        out['text_emb_dict'] = {key: self.text2idx_dict[item['text_emb_dict'][key]] for key in sorted_ids}

        return out

    def collate_batch(self, batch):
        """Collate per-query items into batched tensors for the reranker.

        Note: reads the module-level `args` (args.vector_dim) — set by the
        script entry point before DataLoader iteration starts.
        """
        # Queries stay as raw text.
        batch_q = [batch[i]['query'] for i in range(len(batch))]
        q_text = batch_q

        # Candidate ids, in the (score-sorted) order produced by __getitem__.
        batch_c = [list(batch[i]['score_vector_dict'].keys()) for i in range(len(batch))]  # [batch, 100]
        batch_c = torch.tensor(batch_c)
        # Retriever score vectors, truncated to the configured dimension.
        c_score_vector = [list(batch[i]['score_vector_dict'].values()) for i in range(len(batch))]  # [batch, 100, 4]
        c_score_vector = torch.tensor(c_score_vector)[:, :, :args.vector_dim]  # [batch, 100, 4]

        # Gather candidate text embeddings from the shared matrix.
        c_text_emb = [self.text_emb_matrix[list(batch[i]['text_emb_dict'].values())].unsqueeze(0) for i in range(len(batch))]
        c_text_emb = torch.concat(c_text_emb, dim=0)  # [bs, 100, 768]

        # Symbolic encodings.
        c_symb_enc = [list(batch[i]['symb_enc_dict'].values()) for i in range(len(batch))]
        c_symb_enc = torch.tensor(c_symb_enc)  # [bs, 100, 3]

        # Ground-truth answer ids (ragged list per query).
        ans_ids = [batch[i]['ans_ids'] for i in range(len(batch))]

        # Candidate ids as plain lists, aligned with the score tensors.
        pred_ids = batch_c.tolist()

        feed_dict = {
            'query': q_text,
            'c_score_vector': c_score_vector,
            'c_text_emb': c_text_emb,
            'c_symb_enc': c_symb_enc,
            'ans_ids': ans_ids,
            'pred_ids': pred_ids

        }

        return feed_dict
|
270 |
+
|
271 |
+
|
272 |
+
# ******* loss function ********
|
273 |
+
def loss_fn(scores_pos, scores_neg):
    """Listwise cross-entropy loss over one positive and many negatives.

    Args:
        scores_pos: [B, 1] scores of the positive candidate per query.
        scores_neg: [B, K, 1] scores of K sampled negatives per query.

    Returns:
        Scalar loss: cross entropy with the positive fixed at class index 0.
    """
    criterion = CrossEntropyLoss(ignore_index=-1)

    # Column 0 holds the positive; the flattened negatives follow.
    all_scores = torch.cat((scores_pos, scores_neg.squeeze(-1)), dim=1)  # B x (1 + K)

    # The correct "class" is always the positive, i.e. index 0.
    targets = torch.zeros(all_scores.size(0), dtype=torch.long, device=all_scores.device)

    return criterion(all_scores, targets)
|
288 |
+
|
289 |
+
|
290 |
+
|
291 |
+
def batch_evaluator(skb, scores_cand, ans_ids, batch):
    """Compute hit@1, hit@5, recall@20 and MRR for one evaluation batch.

    Args:
        skb: knowledge base object exposing `candidate_ids` (full candidate pool).
        scores_cand: [bs, 100] (or [bs, 100, 1]) reranker scores for the
            candidates listed in batch['pred_ids'].
        ans_ids: list (len bs) of lists of ground-truth answer ids.
        batch: collated batch dict; only 'pred_ids' is read here.

    Returns:
        dict mapping metric name -> list of per-query values (floats).
    """

    results = {}

    # **** batch wise evaluation ****
    # Scatter the per-candidate scores into a dense [bs, num_all_candidates]
    # matrix so all metrics can be computed with batched tensor ops.
    candidates_ids = skb.candidate_ids
    id_to_idx = {candidate_id: idx for idx, candidate_id in enumerate(candidates_ids)}


    # initialize the pred_matrix; rows default to 0 for unscored candidates
    pred_matrix = torch.zeros((scores_cand.shape[0],len(candidates_ids)))


    # flatten the per-query candidate id lists into one long list
    flat_pred_ids = torch.tensor(batch['pred_ids']).flatten().tolist()


    # map every candidate id to its column index in pred_matrix
    pred_idx = [id_to_idx[pred_id] for pred_id in flat_pred_ids]


    # back to per-query shape
    pred_idx = torch.tensor(pred_idx).reshape(scores_cand.shape[0], -1) # [bs, 100]

    # move pred_matrix to the same device as the scores
    pred_matrix = pred_matrix.to(scores_cand.device)

    # advanced indexing: row r, columns pred_idx[r] <- scores of query r
    pred_matrix[torch.arange(scores_cand.shape[0]).unsqueeze(1), pred_idx] = scores_cand.squeeze(-1) # [bs, num_candidates]



    # Build the binary answer matrix the same way.

    # Flatten ans_ids to a single list and map them to indices
    flat_ans_idx = [id_to_idx[a_id] for sublist in ans_ids for a_id in sublist]

    # Row index for every flattened answer (query r repeated len(ans_ids[r]) times)
    row_indices = torch.repeat_interleave(torch.arange(len(ans_ids)), torch.tensor([len(sublist) for sublist in ans_ids]))

    # ans_matrix[r, c] == 1 iff candidate c is a gold answer for query r
    ans_matrix = torch.zeros((scores_cand.shape[0], len(candidates_ids)), device=scores_cand.device)
    ans_matrix[row_indices, torch.tensor(flat_ans_idx, device=scores_cand.device)] = 1





    # hit@1: is the top-scored candidate a gold answer?
    max_score, max_idx = torch.max(pred_matrix, dim=1)
    # look up the gold label of the argmax column
    batch_hit1 = ans_matrix[torch.arange(scores_cand.shape[0]), max_idx]
    hit1_list = batch_hit1.tolist()




    # hit@5: any gold answer among the top-5 columns?
    _, top5_idx = torch.topk(pred_matrix, 5, dim=1)
    batch_hit5 = ans_matrix[torch.arange(scores_cand.shape[0]).unsqueeze(1), top5_idx]

    # max over the 5 retrieved labels per row
    batch_hit5 = torch.max(batch_hit5, dim=1)[0]
    hit5_list = batch_hit5.tolist()



    # recall@20: fraction of gold answers retrieved in the top-20
    _, top20_idx = torch.topk(pred_matrix, 20, dim=1)
    batch_recall20 = ans_matrix[torch.arange(scores_cand.shape[0]).unsqueeze(1), top20_idx]
    # count gold answers found in the top-20
    batch_recall20 = torch.sum(batch_recall20, dim=1)
    # normalize by the number of gold answers per query
    batch_recall20 = batch_recall20 / torch.sum(ans_matrix, dim=1)
    recall20_list = batch_recall20.tolist()




    # MRR: reciprocal rank of the first gold answer in the full ranking
    _, rank_idx = torch.sort(pred_matrix, dim=1, descending=True)
    # reorder gold labels by predicted rank
    batch_mrr = ans_matrix[torch.arange(scores_cand.shape[0]).unsqueeze(1), rank_idx]
    # argmax returns the first position holding a 1 (0-based rank)
    batch_mrr = torch.argmax(batch_mrr, dim=1)
    # convert to 1-based rank
    batch_mrr += 1
    # reciprocal rank
    batch_mrr = 1 / batch_mrr.float()
    mrr_list = batch_mrr.tolist()




    results['hit@1'] = hit1_list
    results['hit@5'] = hit5_list
    results['recall@20'] = recall20_list
    results['mrr'] = mrr_list






    return results
|
401 |
+
|
402 |
+
|
403 |
+
|
404 |
+
# ***** evaluate *****
|
405 |
+
@torch.no_grad()
def evaluate(reranker, test_loader):
    """Score every batch of `test_loader` and return averaged retrieval metrics.

    Relies on the module-level `skb`, `move_to_cuda` and `batch_evaluator`.
    Returns a dict with keys 'hit@1', 'hit@5', 'recall@20', 'mrr'.
    """
    reranker.eval()

    metric_names = ("hit@1", "hit@5", "recall@20", "mrr")
    per_query = {name: [] for name in metric_names}

    # Raw per-batch outputs, kept around for debugging/inspection.
    pred_list = []
    scores_cand_list = []
    ans_ids_list = []

    for batch in tqdm(test_loader, desc='Evaluating', position=0):
        batch = move_to_cuda(batch)

        # Unwrap DataParallel so the custom eval_batch method is reachable.
        model = reranker.module if isinstance(reranker, nn.DataParallel) else reranker
        scores_cand = model.eval_batch(batch)

        ans_ids = batch['ans_ids']
        batch_results = batch_evaluator(skb, scores_cand, ans_ids, batch)

        for name, values in batch_results.items():
            per_query[name].extend(values)

        pred_list.extend(batch['pred_ids'])
        scores_cand_list.extend(scores_cand.cpu().tolist())
        ans_ids_list.extend(ans_ids)

    avg_results = {name: np.mean(per_query[name]) for name in metric_names}

    print(f"Results: {avg_results}")

    return avg_results
|
464 |
+
|
465 |
+
|
466 |
+
# ***** train *****
|
467 |
+
def main(train_data, val_data, test_data, skb, dataset_name, args):
    """Train the PathReranker and save checkpoints plus a JSON result summary.

    Args:
        train_data / val_data / test_data: pickled payloads with
            'text2emb_dict' and 'data' (see TrainDataset / TestDataset).
        skb: knowledge base providing the full candidate id pool.
        dataset_name: used to build checkpoint/output directories.
        args: parsed CLI namespace (epochs, device, lr, vector_dim, ...).
    """

    epochs = args.epochs
    device = args.device

    train_size = args.train_batch_size
    test_size = 64  # fixed eval batch size

    train_dataset = TrainDataset(train_data)
    train_loader = DataLoader(train_dataset, batch_size=train_size, num_workers=32, collate_fn=train_dataset.collate_batch, drop_last=True)

    test_dataset = TestDataset(test_data)
    test_loader = DataLoader(test_dataset, batch_size=test_size, num_workers=32, collate_fn=test_dataset.collate_batch)

    val_dataset = TestDataset(val_data)
    val_loader = DataLoader(val_dataset, batch_size=test_size, num_workers=32, collate_fn=val_dataset.collate_batch)


    # ***** Model *****
    # NOTE(review): "socre_vector_input_dim" is the (misspelled) keyword
    # actually declared by PathReranker — do not "fix" it here alone.
    reranker = PathReranker(socre_vector_input_dim=args.vector_dim, text_emb_input_dim=768, symb_enc_dim=3, args=args)
    save_dir = f"./data/checkpoints/{dataset_name}/path"
    os.makedirs(save_dir, exist_ok=True)

    reranker.to(device)
    # wrap for multi-GPU data parallelism
    reranker = nn.DataParallel(reranker)


    optimizer = torch.optim.Adam(reranker.parameters(), lr=args.lr)
    best_val_hit1 = float('-inf')

    # Baseline evaluation before any training step.
    val_results = evaluate(reranker, val_loader)
    print(f"Val evaluation")
    print(val_results)


    test_results = evaluate(reranker, test_loader)
    print(f"Test evaluation")
    print(test_results)

    # log both val and test results
    wandb.log({'val_mrr': val_results['mrr'], 'val_hit1': val_results['hit@1'], 'val_hit5': val_results['hit@5'], 'val_recall@20': val_results['recall@20'],
               'test_mrr': test_results['mrr'], 'test_hit1': test_results['hit@1'], 'test_hit5': test_results['hit@5'], 'test_recall@20': test_results['recall@20']})

    best_test_results = {}
    for epoch in tqdm(range(epochs), desc='Training Epochs', position=0):
        total_loss = 0.0
        reranker.train()
        count = 0
        total_instances = 0

        for batch in tqdm(train_loader):
            batch = move_to_cuda(batch)

            scores_pos, scores_neg = reranker(batch)

            batch_loss = loss_fn(scores_pos, scores_neg)

            # standard optimization step
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            count += 1
            # accumulate loss normalized later by the number of instances
            total_instances += scores_pos.shape[0]
            total_loss += batch_loss.item()


        train_loss = total_loss / total_instances

        print(f"Epoch {epoch+1}/{epochs}, Average Train Loss: {train_loss}")



        val_results = evaluate(reranker, val_loader)
        print(f"Val evaluation")
        print(val_results)


        test_results = evaluate(reranker, test_loader)
        print(f"Test evaluation")
        print(test_results)

        # log both val and test results
        wandb.log({'val_mrr': val_results['mrr'], 'val_hit1': val_results['hit@1'], 'val_hit5': val_results['hit@5'], 'val_recall@20': val_results['recall@20'],
                   'test_mrr': test_results['mrr'], 'test_hit1': test_results['hit@1'], 'test_hit5': test_results['hit@5'], 'test_recall@20': test_results['recall@20'],
                   'train_loss': train_loss})


        # save the best model when val hit@1 is the highest so far
        hit1 = val_results['hit@1']
        if best_val_hit1 < hit1:
            best_val_hit1 = hit1

            save_path = f"{save_dir}/best_{best_val_hit1}.pth"

            # save the unwrapped module's weights when DataParallel is used
            if isinstance(reranker, nn.DataParallel):
                torch.save(reranker.module.state_dict(), save_path)
            else:
                torch.save(reranker.state_dict(), save_path)
            print(f"Checkpoint saved at epoch {epoch+1} with test hits@1 {hit1}")

            args.checkpoint_path = save_path
            best_test_results = test_results



    # save last-epoch checkpoint
    save_path = f"{save_dir}/last_{hit1}.pth"
    if isinstance(reranker, nn.DataParallel):
        torch.save(reranker.module.state_dict(), save_path)
    else:
        torch.save(reranker.state_dict(), save_path)
    print(f"Final checkpoint saved at {save_path}")



    # ***** save the results *****
    results = []
    results.append(
        {
            "config": vars(args),
            "test_results": best_test_results
        }
    )
    # save the results to json, timestamped to avoid overwriting earlier runs
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    output_dir = f"./data/outputs/{dataset_name}"
    os.makedirs(output_dir, exist_ok=True)
    with open(f"{output_dir}/results_{timestamp}.json", "w") as f:
        json.dump(results, f, indent=4)

    print(best_test_results)
|
607 |
+
|
608 |
+
|
609 |
+
|
610 |
+
def get_concat_num(combo):
    """
    Determine the value of concat_num based on the combination of embeddings.
    - score_vec adds +1
    - text_emb adds +1
    - symb_enc adds +3
    """
    # Table-driven form of the original if-chain.
    contributions = {"score_vec": 1, "text_emb": 1, "symb_enc": 3}
    return sum(weight for flag, weight in contributions.items() if combo.get(flag, False))
|
627 |
+
|
628 |
+
|
629 |
+
|
630 |
+
def parse_args():
    """Parse the command-line options of the path-reranker training script.

    Returns:
        argparse.Namespace with model, dataset, path and checkpoint settings.
    """
    cli = argparse.ArgumentParser(description="Run Pathreranker with dynamic combinations of embeddings.")

    # Model / optimization configuration.
    cli.add_argument("--train_batch_size", type=int, default=5, help="Batch size for training or evaluation.")
    cli.add_argument("--lr", type=float, default=1e-4, help="Learning rate for optimizer.")
    cli.add_argument("--epochs", type=int, default=100, help="Number of epochs to train the model.")
    cli.add_argument("--device", type=str, default="cuda", help="Device to run the model (e.g., 'cuda' or 'cpu').")

    # Dataset selection and data file locations.
    cli.add_argument("--dataset_name", type=str, default="mag", help="Name of the dataset to use.")
    cli.add_argument("--train_path", type=str, default=f"../mag_train.pkl", help="Path to the training data.")
    cli.add_argument("--test_path", type=str, default=f"../mag_test.pkl", help="Path to the test data.")
    cli.add_argument("--val_path", type=str, default=f"../mag_val.pkl", help="Path to the validation data.")

    # Embedding-combination and checkpoint settings.
    cli.add_argument("--concat_num", type=int, default=0, help="Number of concatenation of embeddings.")
    cli.add_argument("--checkpoint_path", type=str, default="", help="Path to save the checkpoints.")
    cli.add_argument("--vector_dim", type=int, default=4, help="Dimension of the similarity vector.")

    return cli.parse_args()
|
658 |
+
|
659 |
+
|
660 |
+
if __name__ == "__main__":
    # Script entry point: load pickled splits, build the config namespace,
    # start a wandb run, and hand everything to main().

    base_args = parse_args()

    test_path = base_args.test_path
    train_path = base_args.train_path
    val_path = base_args.val_path
    dataset_name = base_args.dataset_name

    # NOTE(review): pkl.load on these files assumes they are trusted local
    # artifacts — never point these paths at untrusted input.
    with open(test_path, "rb") as f:
        test_data = pkl.load(f)

    with open(train_path, "rb") as f:
        train_data = pkl.load(f)

    with open(val_path, "rb") as f:
        val_data = pkl.load(f)

    # load the semi-structured knowledge base for this dataset
    skb = load_skb(dataset_name)

    # enable all three feature families (score vector, text emb, symbolic enc)
    combo = {
        "text_emb": True,
        "score_vec": True,
        "symb_enc": True
    }
    concat_num = get_concat_num(combo)

    wandb.init(project=f'Reranking-{dataset_name}', name=f"path")
    # merge CLI args with the combo flags into a single namespace
    args = argparse.Namespace(**vars(base_args), **combo)
    args.concat_num = concat_num

    main(train_data, val_data, test_data, skb, dataset_name, args)
|
694 |
+
|
Reranking/train_eval_path_prime.py
ADDED
@@ -0,0 +1,691 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import os
|
3 |
+
sys.path.append(os.path.dirname(os.path.abspath(os.getcwd())))
|
4 |
+
|
5 |
+
import pickle as pkl
|
6 |
+
from torch.utils.data import Dataset, DataLoader
|
7 |
+
import torch
|
8 |
+
from tqdm import tqdm
|
9 |
+
import wandb
|
10 |
+
import numpy as np
|
11 |
+
import time
|
12 |
+
import torch.nn as nn
|
13 |
+
from torch.nn import CrossEntropyLoss
|
14 |
+
import random
|
15 |
+
from collections import defaultdict
|
16 |
+
|
17 |
+
|
18 |
+
from Reranking.utils import move_to_cuda, seed_everything
|
19 |
+
from Reranking.rerankers.path import PathReranker
|
20 |
+
from stark_qa import load_qa, load_skb
|
21 |
+
import torch.nn.functional as F
|
22 |
+
import argparse
|
23 |
+
import json
|
24 |
+
import time
|
25 |
+
|
26 |
+
# set the seed
|
27 |
+
seed_everything(42)
|
28 |
+
|
29 |
+
# ***** Dataset *****
|
30 |
+
class TrainDataset(Dataset):
    """
    Custom Dataset for the training data.
    Each instance contains multiple positive and negative candidates.

    Items are (query, pos_score_vector, pos_text_emb, pos_symb_enc) tuples;
    negatives are sampled per query in collate_batch from sorted_query2neg.
    """
    def __init__(self, saved_data, max_neg_candidates=100):
        """
        Build (query, positive) pairs and the per-query negative pool.

        Args:
            saved_data: dict with 'text2emb_dict' (text key -> embedding
                tensor) and 'data' (list of per-query dicts).
            max_neg_candidates: number of negatives sampled per query
                in collate_batch.

        10s for 1000 data
        """
        print(f"start processing training dataset...")
        s_time = time.time()
        self.max_neg_candidates = max_neg_candidates
        self.sorted_query2neg = defaultdict(list)


        self.text2emb_dict = saved_data['text2emb_dict']
        self.data = saved_data['data']


        # separate neg and pos, and prepare (query, pos) pairs
        new_data = []

        for i in range(len(self.data)):
            neg_ids = []
            pos_ids = []
            item = self.data[i]


            candidates_dict = item['pred_dict']
            ans_ids = item['ans_ids']
            # positives = gold answers that actually appear in the candidates
            for ans_id in ans_ids:
                if ans_id in candidates_dict.keys():
                    pos_ids.append(ans_id)
            neg_ids = list(set(candidates_dict.keys()) - set(pos_ids))

            # load scores vector
            score_vector_dict = item['score_vector_dict']

            # load the text path, str format
            text_emb_dict = item['text_emb_dict']

            # load the symb_enc_dict
            symb_enc_dict = item['symb_enc_dict']


            self.data[i]['pos_ids'] = pos_ids
            self.data[i]['neg_ids'] = neg_ids

            query = item['query']
            # one training instance per positive candidate
            for pos_id in pos_ids:
                new_data.append((query, score_vector_dict[pos_id], self.text2emb_dict[text_emb_dict[pos_id]], symb_enc_dict[pos_id]))


            # negatives sorted by retriever score (hardest first)
            neg_dict = {neg_id: candidates_dict[neg_id] for neg_id in neg_ids}
            sorted_neg_ids = sorted(neg_dict.keys(), key=lambda x: neg_dict[x], reverse=True) # return list


            self.sorted_query2neg[query] = [(score_vector_dict[neg_id], self.text2emb_dict[text_emb_dict[neg_id]], symb_enc_dict[neg_id]) for neg_id in sorted_neg_ids]


        self.data = new_data
        print(f"Complete data preparation")
        print(f"Time: {time.time() - s_time}")




    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # returns a (query, score_vector, text_emb, symb_enc) tuple
        return self.data[idx]

    def collate_batch(self, pairs):
        """Collate positive tuples and sample negatives into batch tensors.

        Note: reads the module-level `args` (args.vector_dim).
        """
        s_time = time.time()  # NOTE(review): unused; kept for parity

        # queries stay as raw text
        batch_q = [pair[0] for pair in pairs] # q is text
        q_text = batch_q


        # pos
        # get the score vector
        batch_p_score_vector = [pair[1] for pair in pairs] # p is score vector
        batch_p_score_vector = torch.tensor(batch_p_score_vector) # [bs, 4]
        batch_p_score_vector = batch_p_score_vector[:, :args.vector_dim]
        # get the text emb
        batch_p_text_emb = [pair[2] for pair in pairs] # p is text emb
        batch_p_text_emb = torch.concat(batch_p_text_emb, dim=0) # [bs, 768]
        # get the symb_enc
        batch_p_symb_enc = [pair[3] for pair in pairs] # p is symb_enc
        batch_p_symb_enc = torch.tensor(batch_p_symb_enc) # [bs, 3]


        # Negative samples: sample with replacement from this query's pool
        batch_n = [random.choices(self.sorted_query2neg[query], k=self.max_neg_candidates) for query in batch_q] # allow duplicates


        # get the score vector
        batch_n_score_vector = [pair[0] for sublist in batch_n for pair in sublist]
        batch_n_score_vector = torch.tensor(batch_n_score_vector) # [bs*100, 4]
        # reshape to [bs, 100, 4]
        batch_n_score_vector = batch_n_score_vector.reshape(len(batch_q), self.max_neg_candidates, -1) # [bs, 100, 4]
        batch_n_score_vector = batch_n_score_vector[:, :, :args.vector_dim]

        # get the text emb
        batch_n_text_emb = [pair[1] for sublist in batch_n for pair in sublist]
        batch_n_text_emb = torch.concat(batch_n_text_emb, dim=0) # [bs*100, 768]
        # reshape to [bs, 100, 768]
        batch_n_text_emb = batch_n_text_emb.reshape(len(batch_q), self.max_neg_candidates, -1) # [bs, 100, 768]

        # get the symb_enc
        batch_n_symb_enc = [pair[2] for sublist in batch_n for pair in sublist]
        batch_n_symb_enc = torch.tensor(batch_n_symb_enc) # [bs*100, 3]
        # reshape to [bs, 100, 3]
        batch_n_symb_enc = batch_n_symb_enc.reshape(len(batch_q), self.max_neg_candidates, -1) # [bs, 100, 3]





        # Create a dictionary for the batch
        feed_dict = {
            'query': q_text,
            'p_score_vector': batch_p_score_vector,
            'p_text_emb': batch_p_text_emb,
            'p_symb_enc': batch_p_symb_enc,
            'n_score_vector': batch_n_score_vector,
            'n_text_emb': batch_n_text_emb,
            'n_symb_enc': batch_n_symb_enc,

        }


        return feed_dict
|
170 |
+
|
171 |
+
|
172 |
+
class TestDataset(Dataset):
    """
    Dataset used at evaluation time.

    data format: {
        "query": query,
        "pred_dict": {node_id: score},
        'score_vector_dict': {node_id: [bm25, bm_25, bm25, ada]},
        "text_emb_dict": {node_id: text_emb},
        "ans_ids": [],
    }
    """

    def __init__(self, saved_data):
        """
        Args:
            saved_data: dict with
                'text2emb_dict': {text_key: [1, 768] tensor} — assumed shape, confirm upstream,
                'data': list of per-query dicts (see class docstring).
        """
        print(f"Start processing test dataset...")
        self.text2emb_dict = saved_data['text2emb_dict']
        self.data = saved_data['data']

        # Stack all path-text embeddings into a single matrix so that a batch
        # can be gathered with one fancy-index in collate_batch.
        self.text_emb_matrix = list(self.text2emb_dict.values())
        self.text_emb_matrix = torch.concat(self.text_emb_matrix, dim=0)

        # Map each text key to its row index in text_emb_matrix.
        self.text2idx_dict = {key: idx for idx, key in enumerate(self.text2emb_dict.keys())}

        print(f"Complete data preparation: {len(self.data)}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the idx-th item restricted to its top-50 candidates.

        Fix: the original implementation mutated self.data[idx] in place;
        on a second access of the same index (e.g. evaluating every epoch
        with num_workers=0) the str->index re-mapping of text_emb_dict
        raised a KeyError because the values had already been converted.
        We now build a shallow copy and leave the stored item untouched,
        making __getitem__ idempotent while returning exactly the same
        content as the original did on first access.
        """
        item = self.data[idx]

        # Rank candidates by the retriever score and keep the top 50.
        pred_dict = item['pred_dict']
        sorted_ids = sorted(pred_dict.keys(), key=lambda x: pred_dict[x], reverse=True)[:50]

        out = dict(item)  # shallow copy; the per-candidate dicts below are rebuilt
        out['score_vector_dict'] = {key: item['score_vector_dict'][key] for key in sorted_ids}
        out['symb_enc_dict'] = {key: item['symb_enc_dict'][key] for key in sorted_ids}
        # Convert text keys into row indices of text_emb_matrix.
        out['text_emb_dict'] = {key: self.text2idx_dict[item['text_emb_dict'][key]] for key in sorted_ids}

        return out

    def collate_batch(self, batch):
        """Collate per-query items into batched tensors for the reranker.

        Note: reads the module-level `args` (args.vector_dim) — set by the
        script entry point before DataLoader iteration starts.
        """
        # Queries stay as raw text.
        batch_q = [batch[i]['query'] for i in range(len(batch))]
        q_text = batch_q

        # Candidate ids, in the (score-sorted) order produced by __getitem__.
        batch_c = [list(batch[i]['score_vector_dict'].keys()) for i in range(len(batch))]  # [batch, 100]
        batch_c = torch.tensor(batch_c)
        # Retriever score vectors, truncated to the configured dimension.
        c_score_vector = [list(batch[i]['score_vector_dict'].values()) for i in range(len(batch))]  # [batch, 100, 4]
        c_score_vector = torch.tensor(c_score_vector)[:, :, :args.vector_dim]  # [batch, 100, 4]

        # Gather candidate text embeddings from the shared matrix.
        c_text_emb = [self.text_emb_matrix[list(batch[i]['text_emb_dict'].values())].unsqueeze(0) for i in range(len(batch))]
        c_text_emb = torch.concat(c_text_emb, dim=0)  # [bs, 100, 768]

        # Symbolic encodings.
        c_symb_enc = [list(batch[i]['symb_enc_dict'].values()) for i in range(len(batch))]
        c_symb_enc = torch.tensor(c_symb_enc)  # [bs, 100, 3]

        # Ground-truth answer ids (ragged list per query).
        ans_ids = [batch[i]['ans_ids'] for i in range(len(batch))]

        # Candidate ids as plain lists, aligned with the score tensors.
        pred_ids = batch_c.tolist()

        feed_dict = {
            'query': q_text,
            'c_score_vector': c_score_vector,
            'c_text_emb': c_text_emb,
            'c_symb_enc': c_symb_enc,
            'ans_ids': ans_ids,
            'pred_ids': pred_ids

        }

        return feed_dict
|
267 |
+
|
268 |
+
|
269 |
+
# ******* loss function ********
|
270 |
+
def loss_fn(scores_pos, scores_neg):
    """Listwise cross-entropy loss where the positive is always class 0.

    Args:
        scores_pos: 2-D tensor of positive-candidate scores (one column per
            positive; inferred from the dim=1 concatenation below).
        scores_neg: 3-D tensor of negative scores with a trailing singleton
            dim that is squeezed away before concatenation.

    Returns:
        Scalar loss tensor.
    """
    # Positive column(s) first, then the negatives.
    all_scores = torch.cat([scores_pos, scores_neg.squeeze(-1)], dim=1)

    # The correct "class" is index 0 (the positive) for every row.
    labels = torch.zeros(all_scores.size(0), dtype=torch.long, device=all_scores.device)

    criterion = CrossEntropyLoss(ignore_index=-1)
    return criterion(all_scores, labels)
285 |
+
|
286 |
+
|
287 |
+
|
288 |
+
def batch_evaluator(skb, scores_cand, ans_ids, batch):
    """Compute hit@1, hit@5, recall@20 and MRR for a batch of reranked candidates.

    Args:
        skb: knowledge base exposing `candidate_ids` (the full candidate pool).
        scores_cand: reranker scores for the predicted candidates,
            shape [bs, k] or [bs, k, 1] (trailing dim is squeezed).
        ans_ids: per-query lists of ground-truth answer ids.
        batch: batch dict; `batch['pred_ids']` holds the scored candidate ids.

    Returns:
        Dict mapping metric name -> list of per-query values.
    """
    bs = scores_cand.shape[0]
    pool_ids = skb.candidate_ids
    id_to_idx = {cand_id: idx for idx, cand_id in enumerate(pool_ids)}

    # Scatter candidate scores into a dense [bs, |pool|] matrix; candidates
    # that were never scored keep a score of 0.
    flat_pred = torch.tensor(batch['pred_ids']).flatten().tolist()
    pred_idx = torch.tensor([id_to_idx[p] for p in flat_pred]).reshape(bs, -1)
    pred_matrix = torch.zeros((bs, len(pool_ids)))
    pred_matrix = pred_matrix.to(scores_cand.device)
    pred_matrix[torch.arange(bs).unsqueeze(1), pred_idx] = scores_cand.squeeze(-1)

    # Dense 0/1 answer matrix aligned with the candidate pool.
    flat_ans_idx = [id_to_idx[a] for answers in ans_ids for a in answers]
    row_indices = torch.repeat_interleave(
        torch.arange(len(ans_ids)),
        torch.tensor([len(answers) for answers in ans_ids]),
    )
    ans_matrix = torch.zeros((bs, len(pool_ids)), device=scores_cand.device)
    ans_matrix[row_indices, torch.tensor(flat_ans_idx, device=scores_cand.device)] = 1

    # hit@1: is the single top-scored candidate an answer?
    _, top1_idx = torch.max(pred_matrix, dim=1)
    hit1_list = ans_matrix[torch.arange(bs), top1_idx].tolist()

    # hit@5: does any of the top-5 candidates hit an answer?
    _, top5_idx = torch.topk(pred_matrix, 5, dim=1)
    hit5_list = torch.max(
        ans_matrix[torch.arange(bs).unsqueeze(1), top5_idx], dim=1
    )[0].tolist()

    # recall@20: fraction of a query's answers recovered in the top-20.
    _, top20_idx = torch.topk(pred_matrix, 20, dim=1)
    hits20 = torch.sum(ans_matrix[torch.arange(bs).unsqueeze(1), top20_idx], dim=1)
    recall20_list = (hits20 / torch.sum(ans_matrix, dim=1)).tolist()

    # MRR: reciprocal rank of the first answer in the full descending ranking.
    _, rank_idx = torch.sort(pred_matrix, dim=1, descending=True)
    first_hit_rank = torch.argmax(
        ans_matrix[torch.arange(bs).unsqueeze(1), rank_idx], dim=1
    ) + 1  # ranks are 1-based
    mrr_list = (1 / first_hit_rank.float()).tolist()

    return {
        'hit@1': hit1_list,
        'hit@5': hit5_list,
        'recall@20': recall20_list,
        'mrr': mrr_list,
    }
398 |
+
|
399 |
+
|
400 |
+
|
401 |
+
# ***** evaluate *****
|
402 |
+
@torch.no_grad()
def evaluate(reranker, test_loader):
    """Run the reranker over a loader and return the averaged retrieval metrics.

    Uses the module-level `skb` for the candidate pool. Metrics are averaged
    over all queries in `test_loader`.
    """
    reranker.eval()

    metric_names = ["hit@1", "hit@5", "recall@20", "mrr"]
    all_results = {name: [] for name in metric_names}

    # Collected for parity with the original flow; populated but not returned.
    pred_list = []
    scores_cand_list = []
    ans_ids_list = []

    for _, batch in enumerate(tqdm(test_loader, desc='Evaluating', position=0)):
        batch = move_to_cuda(batch)

        # DataParallel wraps the model; unwrap to reach eval_batch.
        if isinstance(reranker, nn.DataParallel):
            scores_cand = reranker.module.eval_batch(batch)
        else:
            scores_cand = reranker.eval_batch(batch)

        ans_ids = batch['ans_ids']
        results = batch_evaluator(skb, scores_cand, ans_ids, batch)

        for key in results.keys():
            all_results[key].extend(results[key])

        pred_list.extend(batch['pred_ids'])
        scores_cand_list.extend(scores_cand.cpu().tolist())
        ans_ids_list.extend(ans_ids)

    avg_results = {name: np.mean(all_results[name]) for name in metric_names}
    print(f"Results: {avg_results}")
    return avg_results
461 |
+
|
462 |
+
|
463 |
+
# ***** train *****
|
464 |
+
def main(train_data, val_data, test_data, skb, dataset_name, args):
    """Train the PathReranker, checkpointing whenever validation hit@1 improves.

    Args:
        train_data / val_data / test_data: pre-built reranking samples
            (consumed by TrainDataset / TestDataset).
        skb: knowledge base passed to the evaluator for the candidate pool.
        dataset_name: used to build checkpoint and output directories.
        args: parsed CLI namespace (epochs, lr, device, vector_dim, ...).

    Side effects:
        - writes checkpoints under ./data/checkpoints/<dataset_name>/path
        - writes a results JSON under ./data/outputs/<dataset_name>
        - logs metrics to wandb
        - sets args.checkpoint_path to the best checkpoint path
    """

    epochs = args.epochs
    device = args.device

    train_size = args.train_batch_size
    test_size = 64  # fixed evaluation batch size

    train_dataset = TrainDataset(train_data)
    train_loader = DataLoader(train_dataset, batch_size=train_size, num_workers=32, collate_fn=train_dataset.collate_batch, drop_last=True)

    test_dataset = TestDataset(test_data)
    test_loader = DataLoader(test_dataset, batch_size=test_size, num_workers=32, collate_fn=test_dataset.collate_batch)

    val_dataset = TestDataset(val_data)
    val_loader = DataLoader(val_dataset, batch_size=test_size, num_workers=32, collate_fn=val_dataset.collate_batch)

    # ***** Model *****
    # NOTE: 'socre_vector_input_dim' (sic) matches PathReranker's parameter name.
    reranker = PathReranker(socre_vector_input_dim=args.vector_dim, text_emb_input_dim=768, symb_enc_dim=3, args=args)
    save_dir = f"./data/checkpoints/{dataset_name}/path"
    os.makedirs(save_dir, exist_ok=True)

    reranker.to(device)
    # parallel processing across all visible GPUs
    reranker = nn.DataParallel(reranker)

    optimizer = torch.optim.Adam(reranker.parameters(), lr=args.lr)
    best_val_hit1 = float('-inf')

    # Baseline evaluation before any training step.
    val_results = evaluate(reranker, val_loader)
    print(f"Val evaluation")
    print(val_results)

    test_results = evaluate(reranker, test_loader)
    print(f"Test evaluation")
    print(test_results)

    # log both val and test results
    wandb.log({'val_mrr': val_results['mrr'], 'val_hit1': val_results['hit@1'], 'val_hit5': val_results['hit@5'], 'val_recall@20': val_results['recall@20'],
               'test_mrr': test_results['mrr'], 'test_hit1': test_results['hit@1'], 'test_hit5': test_results['hit@5'], 'test_recall@20': test_results['recall@20']})

    best_test_results = {}
    for epoch in tqdm(range(epochs), desc='Training Epochs', position=0):
        total_loss = 0.0
        reranker.train()
        count = 0
        total_instances = 0

        for batch in tqdm(train_loader):
            batch = move_to_cuda(batch)

            scores_pos, scores_neg = reranker(batch)

            # batch_loss = pairwise_loss(scores_pos, scores_neg)
            batch_loss = loss_fn(scores_pos, scores_neg)

            # clear optimizer
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            count += 1
            # accumulate for the per-instance average loss computed below
            total_instances += scores_pos.shape[0]
            total_loss += batch_loss.item()

        train_loss = total_loss / total_instances

        print(f"Epoch {epoch+1}/{epochs}, Average Train Loss: {train_loss}")

        val_results = evaluate(reranker, val_loader)
        print(f"Val evaluation")
        print(val_results)

        test_results = evaluate(reranker, test_loader)
        print(f"Test evaluation")
        print(test_results)

        # log both val and test results
        wandb.log({'val_mrr': val_results['mrr'], 'val_hit1': val_results['hit@1'], 'val_hit5': val_results['hit@5'], 'val_recall@20': val_results['recall@20'],
                   'test_mrr': test_results['mrr'], 'test_hit1': test_results['hit@1'], 'test_hit5': test_results['hit@5'], 'test_recall@20': test_results['recall@20'],
                   'train_loss': train_loss})

        # save the best model when val hit1 is the highest
        hit1 = val_results['hit@1']
        if best_val_hit1 < hit1:
            best_val_hit1 = hit1

            save_path = f"{save_dir}/best_{best_val_hit1}.pth"

            # unwrap DataParallel before saving so the checkpoint loads without it
            if isinstance(reranker, nn.DataParallel):
                torch.save(reranker.module.state_dict(), save_path)
            else:
                torch.save(reranker.state_dict(), save_path)
            print(f"Checkpoint saved at epoch {epoch+1} with test hits@1 {hit1}")

            args.checkpoint_path = save_path
            best_test_results = test_results

    # save last epoch checkopint
    # NOTE(review): indentation reconstructed from a flattened diff — assumed
    # to run once after the final epoch ("last" checkpoint), reusing hit1 from
    # the last epoch; confirm against the original file.
    save_path = f"{save_dir}/last_{hit1}.pth"
    if isinstance(reranker, nn.DataParallel):
        torch.save(reranker.module.state_dict(), save_path)
    else:
        torch.save(reranker.state_dict(), save_path)
    print(f"Final checkpoint saved at {save_path}")

    # ***** save the results *****
    results = []
    results.append(
        {
            "config": vars(args),
            "test_results": best_test_results
        }
    )
    # save the results to json
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    output_dir = f"./data/outputs/{dataset_name}"
    os.makedirs(output_dir, exist_ok=True)
    with open(f"{output_dir}/results_{timestamp}.json", "w") as f:
        json.dump(results, f, indent=4)

    print(best_test_results)
604 |
+
|
605 |
+
def get_concat_num(combo):
    """Compute the embedding-concatenation width for a feature combination.

    Each enabled feature contributes to the width:
    score_vec -> +1, text_emb -> +1, symb_enc -> +3.

    Args:
        combo: dict of feature-name -> bool flags.

    Returns:
        int: total concatenation count.
    """
    contributions = {"score_vec": 1, "text_emb": 1, "symb_enc": 3}
    return sum(
        weight for feature, weight in contributions.items()
        if combo.get(feature, False)
    )
622 |
+
|
623 |
+
|
624 |
+
|
625 |
+
def parse_args():
    """Build and parse the command-line arguments for the path reranker."""
    parser = argparse.ArgumentParser(description="Run Pathreranker with dynamic combinations of embeddings.")

    # Training configuration.
    parser.add_argument("--train_batch_size", type=int, default=256, help="Batch size for training or evaluation.")
    parser.add_argument("--lr", type=float, default=3e-5, help="Learning rate for optimizer.")
    parser.add_argument("--epochs", type=int, default=100, help="Number of epochs to train the model.")
    parser.add_argument("--device", type=str, default="cuda", help="Device to run the model (e.g., 'cuda' or 'cpu').")

    # Dataset selection and split paths.
    parser.add_argument("--dataset_name", type=str, default="prime", help="Name of the dataset to use.")
    parser.add_argument("--train_path", type=str, default=f"../prime_train.pkl", help="Path to the training data.")
    parser.add_argument("--test_path", type=str, default=f"../prime_test.pkl", help="Path to the test data.")
    parser.add_argument("--val_path", type=str, default=f"../prime_val.pkl", help="Path to the validation data.")

    # Feature / model configuration.
    parser.add_argument("--concat_num", type=int, default=0, help="Number of concatenation of embeddings.")
    parser.add_argument("--checkpoint_path", type=str, default="", help="Path to save the checkpoints.")
    parser.add_argument("--vector_dim", type=int, default=4, help="Dimension of the similarity vector.")

    return parser.parse_args()
653 |
+
|
654 |
+
|
655 |
+
if __name__ == "__main__":

    base_args = parse_args()

    dataset_name = base_args.dataset_name
    train_path = base_args.train_path
    test_path = base_args.test_path
    val_path = base_args.val_path

    # Load the pickled data splits.
    # NOTE: pickle deserialization is only safe on trusted, locally-produced files.
    with open(test_path, "rb") as f:
        test_data = pkl.load(f)

    with open(train_path, "rb") as f:
        train_data = pkl.load(f)

    with open(val_path, "rb") as f:
        val_data = pkl.load(f)

    # load skb (knowledge base with the candidate pool)
    skb = load_skb(dataset_name)

    # enable all feature embeddings
    combo = {
        "text_emb": True,
        "score_vec": True,
        "symb_enc": True
    }
    concat_num = get_concat_num(combo)

    wandb.init(project=f'Reranking-{dataset_name}', name=f"path")
    # Merge the CLI args with the feature flags into a single namespace.
    args = argparse.Namespace(**vars(base_args), **combo)
    args.concat_num = concat_num

    main(train_data, val_data, test_data, skb, dataset_name, args)
691 |
+
|
Reranking/utils.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
from pathlib import Path
|
3 |
+
# Get the absolute path of the current script
|
4 |
+
current_file = Path(__file__).resolve()
|
5 |
+
project_root = current_file.parents[1]
|
6 |
+
# Add the project root to the system path
|
7 |
+
sys.path.append(str(project_root))
|
8 |
+
|
9 |
+
import random
|
10 |
+
import torch
|
11 |
+
import os
|
12 |
+
from stark_qa.evaluator import Evaluator
|
13 |
+
import torch.nn as nn
|
14 |
+
from typing import Any, Union, List, Dict
|
15 |
+
|
16 |
+
|
17 |
+
class ModelForSTaRKQA(nn.Module):
    """Wraps a STaRK knowledge base and exposes its standard evaluator.

    Holds the candidate pool from `skb` and delegates metric computation to
    `stark_qa.evaluator.Evaluator`.
    """

    def __init__(self, skb, query_emb_dir='.'):
        """
        Initializes the model with the given knowledge base.

        Args:
            skb: Knowledge base containing candidate information.
            query_emb_dir: unused here; kept for signature compatibility.
        """
        super(ModelForSTaRKQA, self).__init__()
        self.skb = skb

        self.candidate_ids = skb.candidate_ids
        self.evaluator = Evaluator(self.candidate_ids)

    def evaluate(self,
                 pred_dict: Dict[int, float],
                 answer_ids: Union[torch.LongTensor, List[Any]],
                 metrics: Union[List[str], None] = None,
                 **kwargs: Any) -> Dict[str, float]:
        """
        Evaluates the predictions using the specified metrics.

        Args:
            pred_dict (Dict[int, float]): Predicted answer ids or scores.
            answer_ids (torch.LongTensor): Ground truth answer ids.
            metrics (List[str] | None): Metrics to evaluate, including 'mrr',
                'hit@k', 'recall@k', 'precision@k', 'map@k', 'ndcg@k'.
                Defaults to ['mrr', 'hit@3', 'recall@20'].

        Returns:
            Dict[str, float]: A dictionary of evaluation metrics.
        """
        # Use a None sentinel instead of a shared mutable default list.
        if metrics is None:
            metrics = ['mrr', 'hit@3', 'recall@20']
        return self.evaluator(pred_dict, answer_ids, metrics)

    def evaluate_batch(self,
                       pred_ids: List[int],
                       pred: torch.Tensor,
                       answer_ids: Union[torch.LongTensor, List[Any]],
                       metrics: Union[List[str], None] = None,
                       **kwargs: Any) -> Dict[str, float]:
        """Batch variant of `evaluate`; delegates to the evaluator's batch API."""
        if metrics is None:
            metrics = ['mrr', 'hit@3', 'recall@20']
        return self.evaluator.evaluate_batch(pred_ids, pred, answer_ids, metrics)
58 |
+
|
59 |
+
|
60 |
+
def seed_everything(seed=0):
    """Make runs reproducible: seed every RNG in use and pin cuDNN behavior.

    Args:
        seed: integer seed applied to python's `random`, torch (CPU and all
            CUDA devices) and the PYTHONHASHSEED environment variable.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    # CPU generator and every CUDA device generator.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Deterministic cuDNN kernels; disables the autotuner.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
67 |
+
|
68 |
+
|
69 |
+
def move_to_cuda(sample):
    """Recursively move tensors in `sample` onto the GPU.

    Dicts are walked recursively; tensors are `.cuda()`-ed; anything else is
    returned unchanged (lists are deliberately not descended into). An empty
    input of any kind comes back as an empty dict.
    """
    if not len(sample):
        return {}

    def _transfer(obj):
        if torch.is_tensor(obj):
            return obj.cuda()
        if isinstance(obj, dict):
            return {key: _transfer(value) for key, value in obj.items()}
        # Non-tensor, non-dict leaves (including lists) stay on the host.
        return obj

    return _transfer(sample)
87 |
+
|
88 |
+
|
89 |
+
# Smoke-test entry point: running `python utils.py` directly just confirms
# the module imports cleanly.
if __name__ == "__main__":
    print("Testing Utils")
|