# Hugging Face Space app (page status at capture time: Running)
import functools

import gradio as gr
from transformers import AutoTokenizer, AutoModel
from utils_MMD import extract_features  # Adjust the import path
from MMD_calculate import mmd_two_sample_baseline  # Adjust the import path
# Both texts must tokenize to at least this many tokens (with the selected
# model's tokenizer) before the MMD test is attempted.
MINIMUM_TOKENS: int = 64
# Mean test power strictly below this value -> "Human"; otherwise -> "AI".
THRESHOLD: float = 0.5
def count_tokens(text, tokenizer):
    """Return how many token ids ``tokenizer`` produces for ``text``.

    The tokenizer is invoked with its default settings, and the length of the
    ``input_ids`` attribute on the returned encoding is reported.
    """
    encoding = tokenizer(text)
    token_ids = encoding.input_ids
    return len(token_ids)
@functools.lru_cache(maxsize=8)
def _load_tokenizer(model_name):
    """Load the tokenizer for ``model_name`` once and reuse it on later calls."""
    return AutoTokenizer.from_pretrained(model_name)


@functools.lru_cache(maxsize=1)
def _load_model(model_name):
    """Load ``model_name`` onto the GPU in eval mode, caching only the latest one.

    maxsize=1 so that switching models drops the cache's reference to the
    previous model instead of keeping several large models resident.
    """
    # NOTE(review): assumes a CUDA device is available and that
    # extract_features expects a CUDA-resident model -- confirm.
    model = AutoModel.from_pretrained(model_name).cuda()
    model.eval()
    return model


def run_test_power(model_name, real_text, generated_text, N=10):
    """Classify ``generated_text`` as human- or AI-written via an MMD test.

    Args:
        model_name (str): Hugging Face model name used for tokenization and
            feature extraction.
        real_text (str): Reference (human-written) text to compare against.
        generated_text (str): The input text to classify.
        N (int): Number of repetitions passed to ``mmd_two_sample_baseline``.

    Returns:
        str: "Prediction: Human" if the mean test power is below THRESHOLD,
        otherwise "Prediction: AI". If either text is shorter than
        MINIMUM_TOKENS tokens, a "Too short length..." message instead.

    Raises:
        ValueError: If ``mmd_two_sample_baseline`` returns no values.
    """
    # Only the (small) tokenizer is needed for the length check, so reject
    # short inputs before paying for the (large) model load.
    tokenizer = _load_tokenizer(model_name)
    if (
        count_tokens(real_text, tokenizer) < MINIMUM_TOKENS
        or count_tokens(generated_text, tokenizer) < MINIMUM_TOKENS
    ):
        return (
            f"Too short length. Need a minimum of {MINIMUM_TOKENS} tokens "
            "to calculate Test Power."
        )

    model = _load_model(model_name)

    # Extract features for each text (each passed as a one-element list).
    fea_real_ls = extract_features([real_text], tokenizer, model)
    fea_generated_ls = extract_features([generated_text], tokenizer, model)

    # One test-power value per repetition.
    test_power_ls = mmd_two_sample_baseline(fea_real_ls, fea_generated_ls, N=N)
    if len(test_power_ls) == 0:
        raise ValueError(
            f"mmd_two_sample_baseline returned no test-power values (N={N})."
        )
    power_test_value = sum(test_power_ls) / len(test_power_ls)

    if power_test_value < THRESHOLD:
        return "Prediction: Human"
    return "Prediction: AI"
# CSS for custom styling
css = """
#header { text-align: center; font-size: 1.5em; margin-bottom: 20px; }
#output-text { font-weight: bold; font-size: 1.2em; }
"""

# Human-written reference passage compared against the user's input (the
# public-domain opening of Dickens' "A Tale of Two Cities"). run_test_power
# rejects any text shorter than MINIMUM_TOKENS, so the reference must be
# comfortably longer than that; a one-sentence reference such as
# "The cat sat on the mat." can never pass the check.
DEFAULT_REFERENCE_TEXT = (
    "It was the best of times, it was the worst of times, it was the age of "
    "wisdom, it was the age of foolishness, it was the epoch of belief, it was "
    "the epoch of incredulity, it was the season of Light, it was the season of "
    "Darkness, it was the spring of hope, it was the winter of despair, we had "
    "everything before us, we had nothing before us, we were all going direct "
    "to Heaven, we were all going direct the other way - in short, the period "
    "was so far like the present period, that some of its noisiest authorities "
    "insisted on its being received, for good or for evil, in the superlative "
    "degree of comparison only."
)

# Gradio App
with gr.Blocks(css=css) as app:
    with gr.Row():
        gr.HTML('<div id="header">Human or AI Text Detector</div>')
    with gr.Row():
        gr.Markdown(
            """
            [Paper](https://openreview.net/forum?id=z9j7wctoGV) | [Code](https://github.com/xLearn-AU/R-Detect) | [Contact](mailto:1730421718@qq.com)
            """
        )
    with gr.Row():
        input_text = gr.Textbox(
            label="Input Text",
            placeholder="Enter the text to check",
            lines=8,
        )
    with gr.Row():
        # Editable, but pre-filled so the detector works out of the box.
        reference_text = gr.Textbox(
            label="Reference Text (human-written)",
            value=DEFAULT_REFERENCE_TEXT,
            lines=4,
        )
    with gr.Row():
        model_name = gr.Dropdown(
            [
                "gpt2-medium",
                "gpt2-large",
                "t5-large",
                "t5-small",
                "roberta-base",
                "roberta-base-openai-detector",
                # Hub ids need the owning namespace; bare "falcon-rw-1b"
                # does not resolve to a repository.
                "tiiuae/falcon-rw-1b",
            ],
            label="Select Model",
            value="gpt2-medium",
        )
    with gr.Row():
        submit_button = gr.Button("Run Detection", variant="primary")
        clear_button = gr.Button("Clear", variant="secondary")
    with gr.Row():
        output = gr.Textbox(
            label="Prediction",
            placeholder="Prediction: Human or AI",
            elem_id="output-text",
        )
    with gr.Accordion("Disclaimer", open=False):
        gr.Markdown(
            """
            - **Disclaimer**: This tool is for demonstration purposes only. It is not a foolproof AI detector.
            - **Accuracy**: Results may vary based on input length and quality.
            """
        )
    with gr.Accordion("Citations", open=False):
        gr.Markdown(
            """
            ```
            @inproceedings{zhangs2024MMDMP,
            title={Detecting Machine-Generated Texts by Multi-Population Aware Optimization for Maximum Mean Discrepancy},
            author={Zhang, Shuhai and Song, Yiliao and Yang, Jiahao and Li, Yuanqing and Han, Bo and Tan, Mingkui},
            booktitle = {International Conference on Learning Representations (ICLR)},
            year={2024}
            }
            ```
            """
        )

    # Event inputs must be Gradio components (not raw strings); their values
    # are passed positionally as (model_name, real_text, generated_text).
    submit_button.click(
        run_test_power,
        inputs=[model_name, reference_text, input_text],
        outputs=output,
    )
    # Clears the user's input and the prediction; the reference text is kept.
    clear_button.click(lambda: ("", ""), inputs=[], outputs=[input_text, output])

if __name__ == "__main__":
    app.launch()