Spaces:
Running
Running
Create relative_tester.py
Browse files- relative_tester.py +77 -0
relative_tester.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import nltk
|
3 |
+
from roberta_model_loader import roberta_model
|
4 |
+
from feature_ref_loader import feature_mgt_ref, feature_hwt_ref
|
5 |
+
from meta_train import net
|
6 |
+
from regression_model_loader import regression_model
|
7 |
+
from MMD import MMD_3_Sample_Test
|
8 |
+
from utils import DEVICE, FeatureExtractor, HWT, MGT
|
9 |
+
|
10 |
+
|
11 |
+
class RelativeTester:
    """Label text as "Human" or "AI" using a repeated three-sample MMD test.

    Sentence features extracted from the input are compared against the
    module-level reference sets ``feature_hwt_ref`` and ``feature_mgt_ref``
    through ``MMD_3_Sample_Test``, once per random subsample.
    """

    def __init__(self):
        print("Relative Tester init")
        self.feature_extractor = FeatureExtractor(roberta_model, net)

    def sents_split(self, text):
        """Split *text* into sentences and keep those with more than 5 words.

        Words are counted by whitespace splitting. The NLTK tokenizer data
        is downloaded only when the tokenizer reports it as missing, so a
        normal call performs no network access.
        """
        try:
            sents = nltk.sent_tokenize(text)
        except LookupError:
            # Tokenizer data is not installed yet: fetch both variants
            # ("punkt" and "punkt_tab") and retry once.
            nltk.download("punkt", quiet=True)
            nltk.download("punkt_tab", quiet=True)
            sents = nltk.sent_tokenize(text)
        return [sent for sent in sents if len(sent.split()) > 5]

    @staticmethod
    def _subsample(features, size):
        """Return ``size`` rows of *features* picked by a random permutation."""
        return features[torch.randperm(len(features))[:size]]

    @staticmethod
    def _flatten(features):
        """Reshape *features* to 2-D, keeping the first dimension."""
        return features.view(features.shape[0], -1)

    def test(self, input_text, threshold=0.2, round=20):
        """Classify *input_text* by running the MMD test on random subsamples.

        Args:
            input_text: Text to classify.
            threshold: Largest fraction of positive rounds for which the
                result is still "Human".
            round: Number of subsampling rounds; must be at least 1.

        Returns:
            "Human" or "AI", or a "Too short to test!" message when the
            feature extractor yields one feature or fewer.

        Raises:
            ValueError: If ``round`` is smaller than 1.
        """
        # Reject an empty run up front; otherwise the final average would
        # divide by zero after all the expensive work has been done.
        if round < 1:
            raise ValueError("round must be at least 1")

        print("Relative Tester test")
        # Split the input text
        sents = self.sents_split(input_text)
        print("DEBUG: sents:", len(sents))
        # Extract features
        feature_for_sents = self.feature_extractor.process_sents(sents, False)
        if len(feature_for_sents) <= 1:
            return "Too short to test! Please input more than 2 sentences."

        # Every sample is cut to the size of the smallest of the three sets.
        min_len = min(
            len(feature_for_sents),
            len(feature_hwt_ref),
            len(feature_mgt_ref),
        )

        # NOTE(review): the forward passes below run outside torch.no_grad();
        # confirm whether gradients are needed here.
        h_u_list = []
        for _ in range(round):
            # Draw order (input, HWT reference, MGT reference) is kept fixed
            # so results under a seeded RNG stay reproducible.
            sents_sample = self._subsample(feature_for_sents, min_len)
            hwt_sample = self._subsample(feature_hwt_ref, min_len)
            mgt_sample = self._subsample(feature_mgt_ref, min_len)

            # Only the first returned value is used; it is presumably a 0/1
            # decision flag -- confirm against MMD.MMD_3_Sample_Test.
            h_u, *_ = MMD_3_Sample_Test(
                net.net(sents_sample),
                net.net(hwt_sample),
                net.net(mgt_sample),
                self._flatten(sents_sample),
                self._flatten(hwt_sample),
                self._flatten(mgt_sample),
                net.sigma,
                net.sigma0_u,
                net.ep,
                0.05,  # presumably the significance level -- TODO confirm
            )
            h_u_list.append(h_u)

        # Fraction of rounds whose h_u was positive.
        power = sum(h_u_list) / len(h_u_list)
        print("DEBUG: power:", power)
        print("DEBUG: power list:", h_u_list)
        # Return the result
        return "Human" if power <= threshold else "AI"
|
75 |
+
|
76 |
+
|
77 |
+
# Shared module-level instance, created as soon as this module is imported.
# Construction prints "Relative Tester init" and builds a FeatureExtractor
# from roberta_model and net.
relative_tester = RelativeTester()
|