Improve UI and reduce repetitiveness of generation
- api.py +16 -7
- config.ini +4 -1
- demo.py +40 -2
- main.py +1 -1
- model_factory.py +7 -0
- processors.py +57 -7
- schemes.py +1 -0
- stegno.py +11 -7
api.py
CHANGED
@@ -19,7 +19,7 @@ async def encrypt_api(
     body: EncryptionBody,
 ):
     model, tokenizer = ModelFactory.load_model(body.gen_model)
-    text, msg_rate = generate(
+    text, msg_rate, tokens_info = generate(
         tokenizer=tokenizer,
         model=model,
         prompt=body.prompt,
@@ -32,8 +32,9 @@ async def encrypt_api(
         private_key=body.private_key,
         max_new_tokens_ratio=body.max_new_tokens_ratio,
         num_beams=body.num_beams,
+        repetition_penalty=body.repetition_penalty,
     )
-    return {"text": text, "msg_rate": msg_rate}
+    return {"text": text, "msg_rate": msg_rate, "tokens_info": tokens_info}


 @app.post("/decrypt")
@@ -78,6 +79,9 @@ async def default_config():
                 "encrypt.default", "max_new_tokens_ratio"
             ),
             "num_beams": GlobalConfig.get("encrypt.default", "num_beams"),
+            "repetition_penalty": GlobalConfig.get(
+                "encrypt.default", "repetition_penalty"
+            ),
         },
         "decrypt": {
             "gen_model": GlobalConfig.get("encrypt.default", "gen_model"),
@@ -101,9 +105,14 @@ async def default_config():


 if __name__ == "__main__":
+    # The following are mainly used to satisfy the linter
+    host = GlobalConfig.get("server", "host")
+    host = str(host) if host is not None else "0.0.0.0"
+
     port = GlobalConfig.get("server", "port")
-    if port is None
+    port = int(port) if port is not None else 8000
+
+    workers = GlobalConfig.get("server", "workers")
+    workers = int(workers) if workers is not None else 1
+
+    uvicorn.run("api:app", host=host, port=port, workers=workers)
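For reference, a hypothetical client call exercising the new field could look like the sketch below. The `/encrypt` route name, the host, and the full set of required `EncryptionBody` fields are assumptions (the diff only shows the fields it touches), so treat this as an illustration rather than the exact API.

    # Hypothetical request against the updated encrypt endpoint (route name assumed).
    import requests

    resp = requests.post(
        "http://localhost:6969/encrypt",  # port taken from the new config.ini default
        json={
            "prompt": "Once upon a time",
            "gen_model": "gpt2",
            "num_beams": 4,
            "repetition_penalty": 1.3,  # new field; values > 1.0 discourage repeated tokens
        },
    )
    data = resp.json()
    print(data["msg_rate"])        # fraction of the hidden message recovered from the text
    print(data["tokens_info"][0])  # new per-token debug info consumed by the demo UI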
config.ini
CHANGED
@@ -1,5 +1,7 @@
 [server]
+host = str:0.0.0.0
+port = int:6969
+workers = int:4

 [models.names]
 gpt2 = str:openai-community/gpt2
@@ -32,6 +34,7 @@ window_length = int:1
 private_key = int:0
 max_new_tokens_ratio = float:2.0
 num_beams = int:4
+repetition_penalty = float:1.0

 [decrypt.default]
 gen_model = str:gpt2
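The values follow the repo's `type:value` convention (`int:6969`, `float:1.0`). The snippet below is only an illustration of how such entries could be decoded; it is not the project's actual `GlobalConfig` implementation.

    # Illustrative decoder for "type:value" entries, not the repo's GlobalConfig code.
    def parse_typed(raw: str):
        kind, _, value = raw.partition(":")
        casts = {"str": str, "int": int, "float": float}
        return casts.get(kind, str)(value)

    assert parse_typed("int:6969") == 6969
    assert parse_typed("float:1.0") == 1.0
    assert parse_typed("str:0.0.0.0") == "0.0.0.0"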
demo.py
CHANGED
@@ -19,9 +19,10 @@ def enc_fn(
     private_key: int,
     max_new_tokens_ratio: float,
     num_beams: int,
+    repetition_penalty: float,
 ):
     model, tokenizer = ModelFactory.load_model(gen_model)
-    text, msg_rate = generate(
+    text, msg_rate, tokens_info = generate(
         tokenizer=tokenizer,
         model=model,
         prompt=prompt,
@@ -34,8 +35,32 @@ def enc_fn(
         private_key=private_key,
         max_new_tokens_ratio=max_new_tokens_ratio,
         num_beams=num_beams,
+        repetition_penalty=repetition_penalty,
     )
+    highlight_base = []
+    for token in tokens_info:
+        stat = None
+        if token["base_msg"] != -1:
+            if token["base_msg"] == token["base_enc"]:
+                stat = "correct"
+            else:
+                stat = "wrong"
+        highlight_base.append((repr(token["token"])[1:-1], stat))
+
+    highlight_byte = []
+    for i, token in enumerate(tokens_info):
+        if i == 0 or tokens_info[i - 1]["byte_id"] != token["byte_id"]:
+            stat = None
+            if token["byte_msg"] != -1:
+                if token["byte_msg"] == token["byte_enc"]:
+                    stat = "correct"
+                else:
+                    stat = "wrong"
+            highlight_byte.append([repr(token["token"])[1:-1], stat])
+        else:
+            highlight_byte[-1][0] += repr(token["token"])[1:-1]
+
+    return text, highlight_base, highlight_byte, round(msg_rate * 100, 2)


 def dec_fn(
@@ -89,6 +114,7 @@ if __name__ == "__main__":
                 )
             ),
             gr.Number(int(GlobalConfig.get("encrypt.default", "num_beams"))),
+            gr.Number(float(GlobalConfig.get("encrypt.default", "repetition_penalty"))),
         ],
         outputs=[
             gr.Textbox(
@@ -96,6 +122,18 @@ if __name__ == "__main__":
                 show_label=True,
                 show_copy_button=True,
             ),
+            gr.HighlightedText(
+                label="Text containing message (Base highlighted)",
+                combine_adjacent=False,
+                show_legend=True,
+                color_map={"correct": "green", "wrong": "red"},
+            ),
+            gr.HighlightedText(
+                label="Text containing message (Byte highlighted)",
+                combine_adjacent=False,
+                show_legend=True,
+                color_map={"correct": "green", "wrong": "red"},
+            ),
             gr.Number(label="Percentage of message in text", show_label=True),
         ],
     )
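The two new outputs are `gr.HighlightedText` components, which render a list of `(text, label)` pairs; `enc_fn` builds one list per token (base view) and one with tokens merged by `byte_id` (byte view). A sketch with fabricated values of the shapes involved:

    # Fabricated values showing the shapes enc_fn produces for the new outputs.
    # A label of None means the token carries no message value and is not colored.
    highlight_base = [(" The", None), (" quick", "correct"), (" fox", "wrong")]
    highlight_byte = [[" The quick fox", "wrong"]]  # tokens sharing a byte_id are concatenated

    # gr.HighlightedText(color_map={"correct": "green", "wrong": "red"}) then renders
    # green spans where the encoded value matches the message and red ones where it differs.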
main.py
CHANGED
@@ -171,7 +171,7 @@ def main(args):
     print(f" Max New Tokens Ratio: {args.max_new_tokens_ratio}")
     print(f" Number of Beams: {args.num_beams}")
     print("=" * os.get_terminal_size().columns)
-    text, msg_rate = generate(
+    text, msg_rate, tokens_info = generate(
         tokenizer=tokenizer,
         model=model,
         prompt=args.prompt,
model_factory.py
CHANGED
@@ -70,3 +70,10 @@ class ModelFactory:
     @classmethod
     def get_models_names(cls):
         return list(cls.models_names.keys())
+
+    @classmethod
+    def get_model_max_length(cls, name: str):
+        if name in cls.tokenizers:
+            return cls.tokenizers[name].model_max_length
+        else:
+            return 0
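A small usage sketch of the new helper, assuming the model was loaded through `ModelFactory.load_model` first (which presumably populates `cls.tokenizers`):

    # Assumes load_model has cached the tokenizer for "gpt2" in cls.tokenizers.
    ModelFactory.load_model("gpt2")
    print(ModelFactory.get_model_max_length("gpt2"))   # tokenizer.model_max_length, 1024 for GPT-2
    print(ModelFactory.get_model_max_length("other"))  # 0 for a tokenizer that has not been loaded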
processors.py
CHANGED
@@ -127,6 +127,7 @@ class EncryptorLogitsProcessor(LogitsProcessor, BaseProcessor):
         self.raw_msg = msg
         self.msg = bytes_to_base(msg, self.msg_base)
         self.gamma = gamma
+        self.tokenizer = tokenizer
         special_tokens = [
             tokenizer.bos_token_id,
             tokenizer.eos_token_id,
@@ -169,20 +170,69 @@ class EncryptorLogitsProcessor(LogitsProcessor, BaseProcessor):
     def get_message_len(self):
         return len(self.msg)

+    def __map_input_ids(self, input_ids: torch.Tensor, base_arr, byte_arr):
+        byte_enc_msg = [-1 for _ in range(input_ids.size(0))]
+        base_enc_msg = [-1 for _ in range(input_ids.size(0))]
+        base_msg = [-1 for _ in range(input_ids.size(0))]
+        byte_msg = [-1 for _ in range(input_ids.size(0))]
+
+        values_per_byte = get_values_per_byte(self.msg_base)
+        start = self.start_pos % values_per_byte
+
+        for i, b in enumerate(base_arr):
+            base_enc_msg[i] = base_arr[i]
+            byte_enc_msg[i] = byte_arr[(i - start) // values_per_byte]
+
+        for i, b in enumerate(self.msg):
+            base_msg[i + self.start_pos] = b
+            byte_msg[i + self.start_pos] = self.raw_msg[i // values_per_byte]
+
+        return base_msg, byte_msg, base_enc_msg, byte_enc_msg
+
     def validate(self, input_ids_batch: torch.Tensor):
         res = []
+        tokens_infos = []
         for input_ids in input_ids_batch:
+            # Initialization
+            base_arr = []
+
+            # Loop and obtain values of all tokens
+            for i in range(0, input_ids.size(0)):
+                base_arr.append(self._get_value(input_ids[: i + 1]))
+
+            values_per_byte = get_values_per_byte(self.msg_base)
+
+            # Transform the values to bytes
+            start = self.start_pos % values_per_byte
+            byte_arr = base_to_bytes(base_arr[start:], self.msg_base)
+
+            # Construct the
             cnt = 0
+            enc_msg = byte_arr[self.start_pos // values_per_byte :]
+            for i in range(min(len(enc_msg), len(self.raw_msg))):
                 if self.raw_msg[i] == enc_msg[i]:
                     cnt += 1
             res.append(cnt / len(self.raw_msg))

+            base_msg, byte_msg, base_enc_msg, byte_enc_msg = (
+                self.__map_input_ids(input_ids, base_arr, byte_arr)
+            )
+            tokens = []
+            input_strs = [self.tokenizer.decode([input]) for input in input_ids]
+            for i in range(len(base_enc_msg)):
+                tokens.append(
+                    {
+                        "token": input_strs[i],
+                        "base_enc": base_enc_msg[i],
+                        "byte_enc": byte_enc_msg[i],
+                        "base_msg": base_msg[i],
+                        "byte_msg": byte_msg[i],
+                        "byte_id": (i - start) // values_per_byte,
+                    }
+                )
+            tokens_infos.append(tokens)
+
+        return res, tokens_infos


 class DecryptorProcessor(BaseProcessor):
@@ -199,7 +249,7 @@ class DecryptorProcessor(BaseProcessor):
         bytes_msg = []
         for i, input_ids in enumerate(input_ids_batch):
             msg.append(list())
-            for j in range(
+            for j in range(shift, len(input_ids)):
                 # TODO: this could be slow. Considering reimplement this.
                 value = self._get_value(input_ids[: j + 1])
                 msg[i].append(value)
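The index bookkeeping above maps each generated token both to a single base-`msg_base` value and to the message byte it belongs to. A worked example with illustrative numbers (the real `get_values_per_byte`/`base_to_bytes` helpers come from this repo):

    # Illustrative numbers only.
    msg_base = 2          # each token encodes one base-2 value (a bit)
    values_per_byte = 8   # so 8 tokens make up one message byte
    start_pos = 3         # the message starts at the 4th encoded value

    start = start_pos % values_per_byte           # 3
    # byte_arr is built from base_arr[start:], so token i (i >= start) falls into
    # byte index (i - start) // values_per_byte -- the "byte_id" stored per token.
    assert (3 - start) // values_per_byte == 0    # tokens 3..10 form byte 0
    assert (11 - start) // values_per_byte == 1   # tokens 11..18 form byte 1
    # The message bytes themselves begin at byte index start_pos // values_per_byte == 0.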
schemes.py
CHANGED
@@ -20,6 +20,7 @@ class EncryptionBody(BaseModel):
         "encrypt.default", "max_new_tokens_ratio"
     )
     num_beams: int = GlobalConfig.get("encrypt.default", "num_beams")
+    repetition_penalty: float = GlobalConfig.get('encrypt.default', "repetition_penalty")

 class DecryptionBody(BaseModel):
     text: str
stegno.py
CHANGED
@@ -20,6 +20,7 @@ def generate(
     private_key: Union[int, None] = None,
     max_new_tokens_ratio: float = 2,
     num_beams: int = 4,
+    repetition_penalty: float = 1.0,
 ):
     """
     Generate the sequence containing the hidden data.
@@ -61,17 +62,20 @@ def generate(
         salt_key=salt_key,
         private_key=private_key,
     )
-    min_length = start_pos + logits_processor.get_message_len()
-    max_length = int(
+    min_length = prompt_size + start_pos + logits_processor.get_message_len()
+    max_length = prompt_size + int(
         start_pos + logits_processor.get_message_len() * max_new_tokens_ratio
     )
+    max_length = min(max_length, tokenizer.model_max_length)
+    min_length = min(min_length, max_length)
     output_tokens = model.generate(
         **tokenized_input,
         logits_processor=transformers.LogitsProcessorList([logits_processor]),
+        min_length=min_length,
+        max_length=max_length,
         do_sample=True,
-        num_beams=num_beams
+        num_beams=num_beams,
+        repetition_penalty=float(repetition_penalty),
     )

     output_tokens = output_tokens[:, prompt_size:]
@@ -81,9 +85,9 @@ def generate(
     output_tokens_post = tokenizer(output_text, return_tensors="pt").to(
         model.device
     )
-    msg_rates = logits_processor.validate(output_tokens_post.input_ids)
+    msg_rates, tokens_infos = logits_processor.validate(output_tokens_post.input_ids)

-    return output_text, msg_rates[0]
+    return output_text, msg_rates[0], tokens_infos[0]


 def decrypt(
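With illustrative numbers, the new length bookkeeping works out as follows; both bounds now count the prompt and are clamped to the tokenizer's context window, whereas the old bounds included neither adjustment.

    # Illustrative numbers, not from a real run.
    prompt_size = 10                      # tokens in the prompt
    start_pos = 0
    message_len = 40                      # logits_processor.get_message_len()
    max_new_tokens_ratio = 2.0

    min_length = prompt_size + start_pos + message_len                               # 50
    max_length = prompt_size + int(start_pos + message_len * max_new_tokens_ratio)   # 90

    model_max_length = 1024               # e.g. GPT-2's context size
    max_length = min(max_length, model_max_length)   # still 90 here
    min_length = min(min_length, max_length)         # keeps min <= max after clamping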