File size: 788 Bytes
f0366e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
import numpy as np
from utils import DEVICE


class FeatureRefLoader:
    """Load a saved feature tensor and draw a random reference subset from it."""

    def __init__(self):
        print("Feature Ref Loader init")

    # TODO: document/validate the expected format of the saved feature file.
    def load(self, feature_ref_file_name, num_ref=5000, *, device=None):
        """Load features from *feature_ref_file_name* and return a random subset.

        Args:
            feature_ref_file_name: Path to a file written with ``torch.save``.
                The loaded object is indexed along its first dimension, so it is
                assumed to be a tensor whose rows are feature vectors
                (TODO confirm the exact format).
            num_ref: Maximum number of rows to keep. If the file holds fewer
                rows, all of them are returned in shuffled order.
            device: Device for the returned tensor. Defaults to ``DEVICE``.

        Returns:
            A tensor with ``min(num_ref, N)`` randomly chosen rows on *device*.

        Raises:
            ValueError: If *num_ref* is negative.
        """
        print("Feature Ref Loader load")
        if num_ref < 0:
            # A negative slice bound would silently mean "all but the last k".
            raise ValueError(f"num_ref must be non-negative, got {num_ref}")
        target_device = DEVICE if device is None else device
        # Load onto the CPU first so the full file never has to fit on the
        # target device; only the sampled subset is moved there afterwards.
        load_ref_data = torch.load(feature_ref_file_name, map_location="cpu")
        # Keep NumPy's global RNG (so existing np.random.seed() calls still
        # reproduce the same subset), but slice the permutation before
        # gathering instead of copying the whole shuffled tensor and
        # discarding most of it. data[perm][:k] == data[perm[:k]].
        indices = np.random.permutation(load_ref_data.shape[0])[:num_ref]
        feature_ref = load_ref_data[torch.as_tensor(indices, dtype=torch.long)]
        return feature_ref.to(target_device)


# Reference feature sets, each a random subset (default size) drawn from a
# saved feature file when this module is imported. Each file gets its own
# FeatureRefLoader instance, loaded in the order listed below.
# NOTE(review): the "./" paths resolve against the process's current working
# directory, not this module's location — confirm that is intended.
# NOTE(review): "HWT"/"MGT" presumably mean human-written / machine-generated
# text features — verify against the code that produced these files.
(
    feature_two_sample_tester_ref,
    feature_hwt_ref,
    feature_mgt_ref,
) = (
    FeatureRefLoader().load(ref_path)
    for ref_path in (
        "./feature_ref_for_test.pt",
        "./feature_ref_HWT.pt",
        "./feature_ref_MGT.pt",
    )
)