File size: 3,658 Bytes
7bf4b88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from Reasoning.text_retrievers.bm25 import BM25
from Reasoning.text_retrievers.ada import Ada
from Reasoning.text_retrievers.contriever import Contriever


def combine_dicts(dicts_list, pred_dict):
    """
        Merge several route dicts into one, restricted to the keys of pred_dict.

        input: dicts_list: list of dicts whose values are sized (e.g. lists)
               pred_dict: mapping whose keys select the entries to keep
        output: when dicts_list holds exactly one dict, that very dict is
                returned untouched (pred_dict is ignored). Otherwise a new dict
                with one entry per key of pred_dict, holding the longest value
                seen for that key (the first one wins on equal length).
        raises: KeyError if a key of pred_dict appears in none of the dicts
    """
    if len(dicts_list) == 1:
        return dicts_list[0]

    longest = {}
    for route_dict in dicts_list:
        for node, route in route_dict.items():
            # Only a strictly longer value replaces one recorded earlier.
            if node not in longest or len(route) > len(longest[node]):
                longest[node] = route

    # Keep only the entries that pred_dict asks for, in pred_dict's key order.
    return {node: longest[node] for node in pred_dict.keys()}
    
def fix_length(paths_dict, max_length=3, pad_value=-1):
    """
        Give every path in paths_dict the same length.

        input: paths_dict: {key: path}, where each path is a sequence (list)
               max_length: target length of every path, must be >= 0 (default 3)
               pad_value: filler placed in front of short paths (default -1)
        output: a new dict with the same keys. Paths longer than max_length
                keep only their last max_length items; shorter paths are
                padded at the beginning with pad_value (and returned as
                lists); paths that already have the right length are stored
                as-is. paths_dict itself is not modified.
        raises: ValueError if max_length is negative
    """
    if max_length < 0:
        raise ValueError(f"max_length must be non-negative, got {max_length}")

    new_paths_dict = {}

    for key, value in paths_dict.items():
        excess = len(value) - max_length
        if excess > 0:
            # Slice from the front offset rather than with a negative index:
            # value[-0:] would keep the whole path when max_length is 0.
            value = value[excess:]
        elif excess < 0:
            # padding at the beginning
            value = [pad_value] * (-excess) + list(value)
        new_paths_dict[key] = value

    return new_paths_dict
            
    

def parse_metapath(metapath):
    """
        Split a metapath string into routes of uniform direction.

        input: metapath: "paper -> author -> paper <- paper"
        output: routes: [['paper', 'author', 'paper'], ['paper', 'paper']]

        Tokens are separated by whitespace and must alternate between node
        names and arrows ("->" or "<-"), starting and ending with a node.
        Every maximal run of identical arrows yields one route; consecutive
        routes share the node where the direction flips. A "<-" route is
        reversed so that every route reads from source to target. A single
        node yields [[node]].

        Returns None for any malformed metapath: an empty string, a leading or
        trailing arrow, two nodes or two arrows in a row, or an unknown
        connector.
    """
    arrows = ("->", "<-")

    # split() without a separator tolerates repeated/leading/trailing blanks,
    # which split(' ') would turn into empty tokens.
    tokens = metapath.split()

    # A well-formed path is node (arrow node)*: an odd number of tokens with
    # nodes at even positions and arrows at odd positions.
    if len(tokens) % 2 == 0:
        return None
    nodes = tokens[0::2]
    connectors = tokens[1::2]
    if any(node in arrows for node in nodes):
        return None
    if any(connector not in arrows for connector in connectors):
        return None

    if len(tokens) == 1:  # single node
        return [tokens]

    def finish(route, direction):
        # Routes collected along "<-" arrows were gathered target-first.
        return route[::-1] if direction == "<-" else route

    routes = []
    direction = connectors[0]
    route = [nodes[0]]
    for connector, node in zip(connectors, nodes[1:]):
        if connector != direction:
            # Direction flips: close the current route and start the next
            # one from the shared turning node.
            routes.append(finish(route, direction))
            route = [route[-1]]
            direction = connector
        route.append(node)
    routes.append(finish(route, direction))

    return routes


def get_text_retriever(dataset_name, retriever_name, skb, **kwargs):
    """
        Build the text retriever selected by retriever_name.

        input: dataset_name: forwarded as the second constructor argument
               retriever_name: "bm25", "ada" or "contriever"
               skb: forwarded as the first constructor argument
               kwargs: "device" (default 'cuda') is forwarded as the third
                       constructor argument of Ada and Contriever; BM25 does
                       not receive it
        output: a BM25, Ada or Contriever instance
        raises: ValueError for any other retriever_name
    """
    if retriever_name == "bm25":
        return BM25(skb, dataset_name)

    if retriever_name not in ("ada", "contriever"):
        raise ValueError(f"Invalid retriever name: {retriever_name}")

    retriever_cls = Ada if retriever_name == "ada" else Contriever
    return retriever_cls(skb, dataset_name, kwargs.get("device", 'cuda'))


def get_scorer(dataset_name, scorer_name, skb, **kwargs):
    """
        Build the scorer selected by scorer_name.

        input: dataset_name: forwarded as the second constructor argument
               scorer_name: "bm25", "ada" or "contriever"
               skb: forwarded as the first constructor argument
               kwargs: "device" (default 'cuda') is forwarded as the third
                       constructor argument of Ada and Contriever; BM25 does
                       not receive it
        output: a BM25, Ada or Contriever instance
        raises: ValueError for any other scorer_name
    """
    if scorer_name == "bm25":
        return BM25(skb, dataset_name)

    # The two remaining scorers share one constructor signature.
    if scorer_name == "ada":
        scorer_cls = Ada
    elif scorer_name == "contriever":
        scorer_cls = Contriever
    else:
        raise ValueError(f"Invalid scorer name: {scorer_name}")

    return scorer_cls(skb, dataset_name, kwargs.get("device", 'cuda'))
    

if __name__ == "__main__":
    # Running this module directly only prints a marker line.
    print("Test utils")