jer233 committed on
Commit
d97734e
·
verified ·
1 Parent(s): e6eb9e0

Create demo.py

Browse files
Files changed (1) hide show
  1. demo.py +123 -0
demo.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModel
3
+ from utils_MMD import extract_features # Adjust the import path
4
+ from MMD_calculate import mmd_two_sample_baseline # Adjust the import path
5
+
6
# Both texts passed to run_test_power must tokenize to at least this many
# tokens; shorter inputs are rejected before any features are extracted.
MINIMUM_TOKENS = 64

# Decision boundary on the averaged test power: values strictly below it are
# reported as "Human", everything else as "AI".
THRESHOLD = 0.5
8
+
9
def count_tokens(text, tokenizer):
    """
    Return how many token ids ``tokenizer`` produces for ``text``.

    Args:
        text (str): The text to measure.
        tokenizer: A callable whose result exposes an ``input_ids`` sequence.

    Returns:
        int: Length of ``input_ids`` for the encoded text.
    """
    encoded = tokenizer(text)
    token_ids = encoded.input_ids
    return len(token_ids)
14
+
15
def run_test_power(model_name, real_text, generated_text, N=10):
    """
    Classify ``generated_text`` as human- or AI-written using an MMD test.

    Features are extracted from both texts with the selected model, the
    test power is computed ``N`` times, and the average is compared to
    ``THRESHOLD``.

    Args:
        model_name (str): Hugging Face model name.
        real_text (str): Example real text for comparison.
        generated_text (str): The input text to classify.
        N (int): Number of repetitions for MMD calculation.

    Returns:
        str: "Prediction: Human" when the average test power is below
        ``THRESHOLD``, "Prediction: AI" otherwise, or a "Too short length"
        message when either text has fewer than ``MINIMUM_TOKENS`` tokens.

    Raises:
        ValueError: If the MMD calculation returns no test power values.
    """
    # Local import keeps this function self-contained; the PyTorch models
    # loaded through transformers already require torch to be installed.
    import torch

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Validate lengths first: only the tokenizer is needed for this check, so
    # too-short inputs are rejected without paying for a full model load.
    too_short = (
        count_tokens(real_text, tokenizer) < MINIMUM_TOKENS
        or count_tokens(generated_text, tokenizer) < MINIMUM_TOKENS
    )
    if too_short:
        return f"Too short length. Need a minimum of {MINIMUM_TOKENS} tokens to calculate Test Power."

    # Use the GPU when one is present instead of failing on CPU-only hosts.
    # NOTE(review): extract_features lives in utils_MMD; confirm it places its
    # inputs on the model's device rather than assuming CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModel.from_pretrained(model_name).to(device)
    model.eval()

    # Extract features
    fea_real_ls = extract_features([real_text], tokenizer, model)
    fea_generated_ls = extract_features([generated_text], tokenizer, model)

    # Calculate test power list
    test_power_ls = mmd_two_sample_baseline(fea_real_ls, fea_generated_ls, N=N)
    if len(test_power_ls) == 0:
        # Averaging an empty list would raise an opaque ZeroDivisionError.
        raise ValueError("MMD calculation returned no test power values.")

    # Compute the average test power value
    power_test_value = sum(test_power_ls) / len(test_power_ls)

    # Classify the text
    if power_test_value < THRESHOLD:
        return "Prediction: Human"
    return "Prediction: AI"
52
+
53
# CSS for custom styling
css = """
#header { text-align: center; font-size: 1.5em; margin-bottom: 20px; }
#output-text { font-weight: bold; font-size: 1.2em; }
"""

# Gradio App: one text box for the text to classify, a model selector,
# run/clear buttons, and a read-only prediction box.
with gr.Blocks(css=css) as app:
    with gr.Row():
        gr.HTML('<div id="header">Human or AI Text Detector</div>')
    with gr.Row():
        gr.Markdown(
            """
            [Paper](https://openreview.net/forum?id=z9j7wctoGV) | [Code](https://github.com/xLearn-AU/R-Detect) | [Contact](mailto:1730421718@qq.com)
            """
        )
    with gr.Row():
        input_text = gr.Textbox(
            label="Input Text",
            placeholder="Enter the text to check",
            lines=8,
        )
    with gr.Row():
        model_name = gr.Dropdown(
            [
                "gpt2-medium",
                "gpt2-large",
                "t5-large",
                "t5-small",
                "roberta-base",
                "roberta-base-openai-detector",
                # Falcon is published under the "tiiuae" organisation; the
                # bare name does not resolve as a Hub repo id.
                "tiiuae/falcon-rw-1b",
            ],
            label="Select Model",
            value="gpt2-medium",
        )
    with gr.Row():
        submit_button = gr.Button("Run Detection", variant="primary")
        clear_button = gr.Button("Clear", variant="secondary")
    with gr.Row():
        output = gr.Textbox(
            label="Prediction",
            placeholder="Prediction: Human or AI",
            elem_id="output-text",
        )
    with gr.Accordion("Disclaimer", open=False):
        gr.Markdown(
            """
            - **Disclaimer**: This tool is for demonstration purposes only. It is not a foolproof AI detector.
            - **Accuracy**: Results may vary based on input length and quality.
            """
        )
    with gr.Accordion("Citations", open=False):
        gr.Markdown(
            """
            ```
            @inproceedings{zhangs2024MMDMP,
              title={Detecting Machine-Generated Texts by Multi-Population Aware Optimization for Maximum Mean Discrepancy},
              author={Zhang, Shuhai and Song, Yiliao and Yang, Jiahao and Li, Yuanqing and Han, Bo and Tan, Mingkui},
              booktitle = {International Conference on Learning Representations (ICLR)},
              year={2024}
            }
            ```
            """
        )

    # Event listener inputs must be Gradio components, so the fixed reference
    # text is carried in a State component instead of a raw Python string.
    # NOTE(review): this sentence is far shorter than MINIMUM_TOKENS, so
    # run_test_power will always answer "Too short length" until a longer
    # human-written reference text is supplied here.
    reference_text = gr.State("The cat sat on the mat.")
    submit_button.click(
        run_test_power,
        inputs=[model_name, reference_text, input_text],
        outputs=output,
    )
    # Reset both the input box and the prediction box.
    clear_button.click(lambda: ("", ""), inputs=[], outputs=[input_text, output])

app.launch()