Mohamad-Jaallouk committed on
Commit
6a49579
·
verified ·
1 Parent(s): 16f821a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +126 -0
app.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ import torch
5
+
6
+
7
class ModelProcessor:
    """Guess whether a text was produced by a language model.

    A causal language model scores every token of the prompt given the
    tokens before it. Three per-token statistics are derived from those
    scores (entropy of the predicted distribution, negative log-likelihood
    of the actual token, and the rank of the actual token among all
    candidates) and compared against fixed thresholds.
    """

    # Per-token statistics are capped at this value before being aggregated.
    STAT_CAP = 4
    # A median capped entropy below this value classifies the text as generated.
    ENTROPY_THRESHOLD = 2.0
    # A mean of (capped rank * capped NLL) below this value classifies the
    # text as generated.
    SCORE_THRESHOLD = 5.2
    # Prompts longer than this many tokens are truncated before scoring.
    MAX_LENGTH = 512

    def __init__(self, repo_id="HuggingFaceTB/cosmo-1b", device=None):
        """Load the tokenizer and the scoring model.

        Args:
            repo_id: Hugging Face Hub repository of the causal LM to load.
            device: Device string handed to ``device_map``. When ``None``
                (the default) CUDA is used if available, otherwise the CPU,
                so the app no longer fails to start on hosts without a GPU.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        self.tokenizer = AutoTokenizer.from_pretrained(repo_id, use_fast=True)

        # NOTE(review): trust_remote_code=True lets the repository execute its
        # own Python code at load time. Keep it only for repositories you
        # trust; confirm whether this model actually needs it.
        self.model = AutoModelForCausalLM.from_pretrained(
            repo_id,
            torch_dtype=torch.bfloat16,
            device_map=device,
            trust_remote_code=True,
        )
        self.model.eval()  # inference only: disables dropout and similar layers

        # Reuse the end-of-sequence token as the padding token.
        self.tokenizer.pad_token = self.tokenizer.eos_token

    @torch.inference_mode()
    def process_data_and_compute_statistics(self, prompt):
        """Classify ``prompt`` as generated (1) or not generated (0).

        Args:
            prompt: The text to analyse. It is truncated to ``MAX_LENGTH``
                tokens.

        Returns:
            ``1`` when the statistics fall below the thresholds (the text
            looks generated), otherwise ``0``. Prompts that tokenize to
            fewer than two tokens contain no next-token prediction to score
            and yield ``0``.
        """
        tokens = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=self.MAX_LENGTH,
        ).to(self.model.device)
        input_ids = tokens["input_ids"]

        # With fewer than two tokens there is no (context, next token) pair,
        # and median() on the resulting empty tensor would raise.
        if input_ids.size(-1) < 2:
            return 0

        logits = self.model(input_ids).logits

        # The logits at position i predict token i + 1, so drop the first
        # label and the last logit row to line them up.
        shifted_labels = input_ids[..., 1:].contiguous()
        shifted_logits = logits[..., :-1, :].contiguous()

        # Entropy of the predicted distribution at every position.
        shifted_probs = torch.softmax(shifted_logits, dim=-1)
        shifted_log_probs = torch.log_softmax(shifted_logits, dim=-1)
        entropy = -torch.sum(shifted_probs * shifted_log_probs, dim=-1).squeeze()

        # Negative log-likelihood of each actual token. The probabilities
        # computed above are reused instead of running softmax a second time.
        true_class_probabilities = shifted_probs.gather(
            -1, shifted_labels.unsqueeze(-1)
        ).view(-1)
        nll = -torch.log(
            true_class_probabilities.clamp(min=1e-9)
        )  # clamp to prevent log(0)

        # Zero-based rank of each actual token when all candidates are
        # sorted from most to least likely.
        ranks = (
            shifted_logits.argsort(dim=-1, descending=True)
            == shifted_labels.unsqueeze(-1)
        ).nonzero()[:, -1]

        cap = self.STAT_CAP
        if entropy.clamp(max=cap).median() < self.ENTROPY_THRESHOLD:
            return 1

        score = (ranks.clamp(max=cap) * nll.clamp(max=cap)).mean()
        return 1 if score < self.SCORE_THRESHOLD else 0
63
+
64
+
65
+ processor = ModelProcessor()
66
+
67
+
68
def detect(prompt):
    """Return a Markdown verdict on whether ``prompt`` looks machine-generated.

    Args:
        prompt: Text entered by the user. May be empty or ``None`` when the
            form is submitted without input.

    Returns:
        A Markdown string: a hint to enter text when the prompt is blank,
        otherwise the "generated" / "not generated" verdict derived from
        the module-level ``processor``.
    """
    # A blank prompt carries nothing to score; answer directly instead of
    # sending it through the model.
    if prompt is None or not prompt.strip():
        return "Please enter some text to analyze."

    prediction = processor.process_data_and_compute_statistics(prompt)
    if prediction == 1:
        return "The text is likely **generated** by a language model."
    return "The text is likely **not generated** by a language model."
74
+
75
+
76
# Gradio front end: a title row with project links, then a prompt box, a
# submit button and a Markdown area that shows the verdict.
with gr.Blocks(
    css="""
    .gradio-container {
        max-width: 800px;
        margin: 0 auto;
    }
    .gr-box {
        box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
        padding: 20px;
        border-radius: 4px;
    }
    .gr-button {
        background-color: #007bff;
        color: white;
        padding: 10px 20px;
        border-radius: 4px;
    }
    .gr-button:hover {
        background-color: #0056b3;
    }
    .hyperlinks a {
        margin-right: 10px;
    }
    """
) as demo:
    with gr.Row():
        with gr.Column(scale=3):
            gr.Markdown("# ENTELL Model Detection")
        with gr.Column(scale=1):
            # NOTE(review): the "paper" and "code" links have empty href
            # values; fill in the real URLs once they are available.
            gr.HTML(
                """
                <p>
                <a href="" target="_blank">paper</a>

                <a href="" target="_blank">code</a>

                <a href="mailto:[email protected]" target="_blank">contact</a>
                </p>
                """,
                elem_classes="hyperlinks",
            )
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                lines=8,
                placeholder="Type your prompt here...",
                label="Prompt",
            )
            submit_btn = gr.Button("Submit", variant="primary")
            output = gr.Markdown()

    # Event listeners must be registered while the Blocks context is open.
    submit_btn.click(fn=detect, inputs=prompt, outputs=output)

# Start the server only when run as a script, so importing this module
# (for tests or reload tooling) does not launch it.
if __name__ == "__main__":
    demo.launch()