"""Classify a text as "Human" or "AI" with repeated three-sample MMD tests.

The input text's sentence features are compared against the ``feature_hwt_ref``
and ``feature_mgt_ref`` reference feature sets loaded at import time.
"""

import torch
import nltk

from roberta_model_loader import roberta_model
from feature_ref_loader import feature_mgt_ref, feature_hwt_ref
from meta_train import net
from regression_model_loader import regression_model
from MMD import MMD_3_Sample_Test
from utils import DEVICE, FeatureExtractor, HWT, MGT

# (lookup path inside the NLTK data directory, downloadable package name)
_NLTK_TOKENIZER_RESOURCES = (
    ("tokenizers/punkt", "punkt"),
    ("tokenizers/punkt_tab", "punkt_tab"),
)


class RelativeTester:
    """Run the relative (three-sample) test on sentence-level features."""

    # Set once every tokenizer resource is known to be available, so later
    # calls to ``sents_split`` skip the lookup entirely.
    _tokenizer_ready = False

    def __init__(self):
        print("Relative Tester init")
        self.feature_extractor = FeatureExtractor(roberta_model, net)

    @classmethod
    def _ensure_tokenizer_data(cls) -> None:
        """Make the NLTK sentence-tokenizer data available, downloading if absent.

        The local NLTK data path is checked first so that no network request
        is made when the data is already installed. The "ready" flag is only
        set when every resource was found or downloaded successfully; after a
        failed download the next call tries again.
        """
        if cls._tokenizer_ready:
            return
        all_available = True
        for resource_path, package in _NLTK_TOKENIZER_RESOURCES:
            try:
                nltk.data.find(resource_path)
            except LookupError:
                downloaded = nltk.download(package, quiet=True)
                all_available = all_available and bool(downloaded)
        cls._tokenizer_ready = all_available

    @staticmethod
    def _random_rows(features, count):
        """Return ``count`` rows of ``features`` chosen by a random permutation."""
        return features[torch.randperm(len(features))[:count]]

    def sents_split(self, text: str) -> list:
        """Split ``text`` into sentences, keeping those with more than 5 words.

        Words are counted by splitting each sentence on whitespace.
        """
        self._ensure_tokenizer_data()
        return [
            sent for sent in nltk.sent_tokenize(text) if len(sent.split()) > 5
        ]

    def test(self, input_text: str, threshold: float = 0.2, round: int = 20):
        """Decide whether ``input_text`` looks human-written or AI-generated.

        Args:
            input_text: Raw text to classify.
            threshold: Upper bound (inclusive) on the mean of the per-round
                test outcomes for the text to be reported as "Human".
            round: Number of random sub-sampling rounds to average over.
                (The name shadows the ``round`` builtin but is kept because
                it is part of the public signature.)

        Returns:
            "Human", "AI", or an explanatory message when fewer than two
            feature rows could be extracted from the text.

        Raises:
            ValueError: If ``round`` is smaller than 1 (there would be no
                outcomes to average).
        """
        print("Relative Tester test")

        # Split the input text
        sents = self.sents_split(input_text)
        print("DEBUG: sents:", len(sents))

        # Extract features
        feature_for_sents = self.feature_extractor.process_sents(sents, False)
        if len(feature_for_sents) <= 1:
            # NOTE(review): the guard rejects 0 or 1 feature rows, while the
            # message asks for "more than 2 sentences". The text is left
            # untouched because callers may match on it.
            return "Too short to test! Please input more than 2 sentences."

        if round < 1:
            raise ValueError(f"round must be at least 1, got {round!r}")

        # All three samples are cut to the size of the smallest set.
        sample_size = min(
            len(feature_for_sents),
            len(feature_hwt_ref),
            len(feature_mgt_ref),
        )

        # Calculate MMD
        h_u_list = []
        for _ in range(round):
            # Sampling order (input, HWT, MGT) is kept fixed so that a seeded
            # RNG yields the same permutations as before.
            sents_sample = self._random_rows(feature_for_sents, sample_size)
            hwt_sample = self._random_rows(feature_hwt_ref, sample_size)
            mgt_sample = self._random_rows(feature_mgt_ref, sample_size)

            # Only the first returned value is used; it is averaged below.
            # Presumably it is a 0/1 rejection indicator and 0.05 is the
            # significance level -- TODO confirm against MMD_3_Sample_Test.
            h_u, *_ = MMD_3_Sample_Test(
                net.net(sents_sample),
                net.net(hwt_sample),
                net.net(mgt_sample),
                sents_sample.view(sents_sample.shape[0], -1),
                hwt_sample.view(hwt_sample.shape[0], -1),
                mgt_sample.view(mgt_sample.shape[0], -1),
                net.sigma,
                net.sigma0_u,
                net.ep,
                0.05,
            )
            h_u_list.append(h_u)

        power = sum(h_u_list) / len(h_u_list)
        print("DEBUG: power:", power)
        print("DEBUG: power list:", h_u_list)

        # Return the result
        return "Human" if power <= threshold else "AI"


relative_tester = RelativeTester()