jer233 commited on
Commit
5df855d
·
verified ·
1 Parent(s): 9275958

Create relative_tester.py

Browse files
Files changed (1) hide show
  1. relative_tester.py +77 -0
relative_tester.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import nltk
3
+ from roberta_model_loader import roberta_model
4
+ from feature_ref_loader import feature_mgt_ref, feature_hwt_ref
5
+ from meta_train import net
6
+ from regression_model_loader import regression_model
7
+ from MMD import MMD_3_Sample_Test
8
+ from utils import DEVICE, FeatureExtractor, HWT, MGT
9
+
10
+
11
+ class RelativeTester:
12
+ def __init__(self):
13
+ print("Relative Tester init")
14
+ self.feature_extractor = FeatureExtractor(roberta_model, net)
15
+
16
+ def sents_split(self, text):
17
+ nltk.download("punkt", quiet=True)
18
+ nltk.download("punkt_tab", quiet=True)
19
+ sents = nltk.sent_tokenize(text)
20
+ return [sent for sent in sents if 5 < len(sent.split())]
21
+
22
+ def test(self, input_text, threshold=0.2, round=20):
23
+ print("Relative Tester test")
24
+ # Split the input text
25
+ sents = self.sents_split(input_text)
26
+ print("DEBUG: sents:", len(sents))
27
+ # Extract features
28
+ feature_for_sents = self.feature_extractor.process_sents(sents, False)
29
+ if len(feature_for_sents) <= 1:
30
+ # print("DEBUG: tooshort")
31
+ return "Too short to test! Please input more than 2 sentences."
32
+ # Cutoff the features
33
+ min_len = min(
34
+ len(feature_for_sents),
35
+ len(feature_hwt_ref),
36
+ len(feature_mgt_ref),
37
+ )
38
+ # Calculate MMD
39
+ h_u_list = []
40
+ p_value_list = []
41
+ t_list = []
42
+
43
+ for i in range(round):
44
+ feature_for_sents_sample = feature_for_sents[
45
+ torch.randperm(len(feature_for_sents))[:min_len]
46
+ ]
47
+ feature_hwt_ref_sample = feature_hwt_ref[
48
+ torch.randperm(len(feature_hwt_ref))[:min_len]
49
+ ]
50
+ feature_mgt_ref_sample = feature_mgt_ref[
51
+ torch.randperm(len(feature_mgt_ref))[:min_len]
52
+ ]
53
+ h_u, p_value, t, *rest = MMD_3_Sample_Test(
54
+ net.net(feature_for_sents_sample),
55
+ net.net(feature_hwt_ref_sample),
56
+ net.net(feature_mgt_ref_sample),
57
+ feature_for_sents_sample.view(feature_for_sents_sample.shape[0], -1),
58
+ feature_hwt_ref_sample.view(feature_hwt_ref_sample.shape[0], -1),
59
+ feature_mgt_ref_sample.view(feature_mgt_ref_sample.shape[0], -1),
60
+ net.sigma,
61
+ net.sigma0_u,
62
+ net.ep,
63
+ 0.05,
64
+ )
65
+
66
+ h_u_list.append(h_u)
67
+ p_value_list.append(p_value)
68
+ t_list.append(t)
69
+
70
+ power = sum(h_u_list) / len(h_u_list)
71
+ print("DEBUG: power:", power)
72
+ print("DEBUG: power list:", h_u_list)
73
+ # Return the result
74
+ return "Human" if power <= threshold else "AI"
75
+
76
+
77
+ relative_tester = RelativeTester()