jer233 committed on
Commit
699e956
·
verified ·
1 Parent(s): e1a5b54

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +56 -0
utils.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import random
3
+ import numpy as np
4
+
5
# Flip to True to place tensors on the first CUDA device instead of the CPU.
gpu_using = False
# Device that tokenized inputs are moved to before the model forward pass.
DEVICE = torch.device("cuda:0" if gpu_using else "cpu")

# String tags; presumably "human-written text" / "machine-generated text"
# labels -- confirm against the callers that consume them.
HWT, MGT = "HWT", "MGT"
12
+
13
+
14
def init_random_seeds(seed: int = 0) -> None:
    """Seed the random number generators used by this module's dependencies.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU plus the current and all CUDA devices), then turns off cuDNN
    benchmarking and turns on cuDNN's deterministic mode.

    Args:
        seed: Integer seed passed to every generator. Defaults to ``0``,
            the value that used to be hard-coded, so existing zero-argument
            calls behave exactly as before.
    """
    print("Init random seeds")
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Explicit CUDA seeding; per the PyTorch docs these calls are silently
    # ignored when CUDA is not available.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Benchmark mode lets cuDNN pick kernels by timing them, which may differ
    # between runs; disable it and request deterministic algorithms instead.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
23
+
24
+
25
class FeatureExtractor:
    """Convert text into padding-masked hidden states, optionally fed to a net.

    Args:
        model: Wrapper object that exposes a ``tokenizer`` callable and a
            ``model`` callable; ``model.model(**tokens)`` must return an
            object with a ``last_hidden_state`` attribute.
        net: Optional wrapper whose ``net`` attribute is called on the masked
            hidden states. When ``None`` the masked hidden states are
            returned as-is.
        max_length: Token length every input is padded or truncated to.
            Defaults to ``100``, the value that used to be hard-coded.
    """

    def __init__(self, model, net=None, max_length=100):
        self.model = model  # TODO: support different models
        self.net = net
        self.max_length = max_length

    def process(self, text, net_required=True):
        """Extract the feature tensor for a single text.

        Args:
            text: The input text; it is tokenized as a batch of one.
            net_required: When true and a ``net`` was supplied, return the
                output of ``self.net.net``; otherwise return the masked
                hidden states.

        Returns:
            The result of ``self.net.net(masked_hidden_states)`` or the
            masked hidden states themselves. Presumably shaped
            ``(1, max_length, hidden)`` in the latter case -- confirm
            against the wrapped model.
        """
        # Tokenize to a fixed length so every text yields the same sequence size.
        tokens = self.model.tokenizer(
            [text],
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        ).to(DEVICE)
        # Predict
        outputs = self.model.model(**tokens)
        # Add a trailing axis so the mask broadcasts over the hidden dimension,
        # then zero out the hidden states at padding positions.
        attention_mask = tokens["attention_mask"].unsqueeze(-1)
        hidden_states_masked = outputs.last_hidden_state * attention_mask
        if net_required and self.net is not None:
            return self.net.net(hidden_states_masked)
        return hidden_states_masked

    def process_sents(self, sents, net_required=True):
        """Extract features for several texts and stack them along dim 0.

        Args:
            sents: Iterable of texts, each passed to :meth:`process` in order.
                Must yield at least one text, since ``torch.cat`` rejects an
                empty list.
            net_required: Forwarded unchanged to :meth:`process`.

        Returns:
            The per-text features concatenated along dimension 0.
        """
        features = [self.process(sent, net_required) for sent in sents]
        return torch.cat(features, dim=0)