tnk2908 committed
Commit ee83d59 · Parent: 4b28b29

Improve UI and reduce repetitiveness of generation

Files changed (8)
  1. api.py +16 -7
  2. config.ini +4 -1
  3. demo.py +40 -2
  4. main.py +1 -1
  5. model_factory.py +7 -0
  6. processors.py +57 -7
  7. schemes.py +1 -0
  8. stegno.py +11 -7
api.py CHANGED
@@ -19,7 +19,7 @@ async def encrypt_api(
     body: EncryptionBody,
 ):
     model, tokenizer = ModelFactory.load_model(body.gen_model)
-    text, msg_rate = generate(
+    text, msg_rate, tokens_info = generate(
         tokenizer=tokenizer,
         model=model,
         prompt=body.prompt,
@@ -32,8 +32,9 @@ async def encrypt_api(
         private_key=body.private_key,
         max_new_tokens_ratio=body.max_new_tokens_ratio,
         num_beams=body.num_beams,
+        repetition_penalty=body.repetition_penalty,
     )
-    return {"text": text, "msg_rate": msg_rate}
+    return {"text": text, "msg_rate": msg_rate, "tokens_info": tokens_info}
 
 
 @app.post("/decrypt")
@@ -78,6 +79,9 @@ async def default_config():
                 "encrypt.default", "max_new_tokens_ratio"
             ),
             "num_beams": GlobalConfig.get("encrypt.default", "num_beams"),
+            "repetition_penalty": GlobalConfig.get(
+                "encrypt.default", "repetition_penalty"
+            ),
         },
         "decrypt": {
             "gen_model": GlobalConfig.get("encrypt.default", "gen_model"),
@@ -101,9 +105,14 @@ async def default_config():
 
 
 if __name__ == "__main__":
+    # The following are mainly used to satisfy the linter
+    host = GlobalConfig.get("server", "host")
+    host = str(host) if host is not None else "0.0.0.0"
+
     port = GlobalConfig.get("server", "port")
-    if port is None:
-        port = 8000
-    else:
-        port = int(port)
-    uvicorn.run("api:app", host="0.0.0.0", port=port, workers=4)
+    port = int(port) if port is not None else 8000
+
+    workers = GlobalConfig.get("server", "workers")
+    workers = int(workers) if workers is not None else 1
+
+    uvicorn.run("api:app", host=host, port=port, workers=workers)
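Note: with this change the /encrypt response gains a tokens_info field alongside text and msg_rate. A minimal sketch of exercising the endpoint (host and port come from config.ini below; the msg field and any body fields not visible in this diff are assumptions about EncryptionBody, not confirmed by it):

    import requests

    resp = requests.post(
        "http://localhost:6969/encrypt",
        json={
            "prompt": "Once upon a time",
            "msg": "secret",                # assumed field name
            "gen_model": "gpt2",
            "repetition_penalty": 1.3,      # >1.0 discourages repeated tokens
        },
    )
    data = resp.json()
    print(data["msg_rate"])
    print(data["tokens_info"][0])           # per-token encode/decode diagnostics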
config.ini CHANGED
@@ -1,5 +1,7 @@
 [server]
-port = int:42069
+host = str:0.0.0.0
+port = int:6969
+workers = int:4
 
 [models.names]
 gpt2 = str:openai-community/gpt2
@@ -32,6 +34,7 @@ window_length = int:1
 private_key = int:0
 max_new_tokens_ratio = float:2.0
 num_beams = int:4
+repetition_penalty = float:1.0
 
 [decrypt.default]
 gen_model = str:gpt2
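Note: values follow a type:value convention (int:, float:, str:), which implies GlobalConfig casts entries when it reads them. A hypothetical reader illustrating that contract (the real GlobalConfig implementation is not part of this diff):

    import configparser

    _CASTS = {"int": int, "float": float, "str": str}

    def read_typed(path: str, section: str, key: str):
        """Parse 'int:6969'-style values into typed Python objects."""
        parser = configparser.ConfigParser()
        parser.read(path)
        raw = parser.get(section, key, fallback=None)
        if raw is None:
            return None
        type_name, _, value = raw.partition(":")
        return _CASTS.get(type_name, str)(value)

    # read_typed("config.ini", "server", "workers") -> 4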
demo.py CHANGED
@@ -19,9 +19,10 @@ def enc_fn(
     private_key: int,
     max_new_tokens_ratio: float,
     num_beams: int,
+    repetition_penalty: float,
 ):
     model, tokenizer = ModelFactory.load_model(gen_model)
-    text, msg_rate = generate(
+    text, msg_rate, tokens_info = generate(
         tokenizer=tokenizer,
         model=model,
         prompt=prompt,
@@ -34,8 +35,32 @@ def enc_fn(
         private_key=private_key,
         max_new_tokens_ratio=max_new_tokens_ratio,
         num_beams=num_beams,
+        repetition_penalty=repetition_penalty,
     )
-    return text, msg_rate
+    highlight_base = []
+    for token in tokens_info:
+        stat = None
+        if token["base_msg"] != -1:
+            if token["base_msg"] == token["base_enc"]:
+                stat = "correct"
+            else:
+                stat = "wrong"
+        highlight_base.append((repr(token["token"])[1:-1], stat))
+
+    highlight_byte = []
+    for i, token in enumerate(tokens_info):
+        if i == 0 or tokens_info[i - 1]["byte_id"] != token["byte_id"]:
+            stat = None
+            if token["byte_msg"] != -1:
+                if token["byte_msg"] == token["byte_enc"]:
+                    stat = "correct"
+                else:
+                    stat = "wrong"
+            highlight_byte.append([repr(token["token"])[1:-1], stat])
+        else:
+            highlight_byte[-1][0] += repr(token["token"])[1:-1]
+
+    return text, highlight_base, highlight_byte, round(msg_rate * 100, 2)
 
 
 def dec_fn(
@@ -89,6 +114,7 @@ if __name__ == "__main__":
                 )
             ),
             gr.Number(int(GlobalConfig.get("encrypt.default", "num_beams"))),
+            gr.Number(float(GlobalConfig.get("encrypt.default", "repetition_penalty"))),
         ],
         outputs=[
             gr.Textbox(
@@ -96,6 +122,18 @@ if __name__ == "__main__":
                 show_label=True,
                 show_copy_button=True,
             ),
+            gr.HighlightedText(
+                label="Text containing message (Base highlighted)",
+                combine_adjacent=False,
+                show_legend=True,
+                color_map={"correct": "green", "wrong": "red"},
+            ),
+            gr.HighlightedText(
+                label="Text containing message (Byte highlighted)",
+                combine_adjacent=False,
+                show_legend=True,
+                color_map={"correct": "green", "wrong": "red"},
+            ),
             gr.Number(label="Percentage of message in text", show_label=True),
         ],
     )
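Note: gr.HighlightedText consumes a list of (text, label) pairs, which is exactly what enc_fn now builds; a None label leaves the span unhighlighted, and color_map assigns colors to the "correct"/"wrong" labels. A standalone sketch of the same display pattern:

    import gradio as gr

    def fake_enc(_prompt):
        # (substring, label) pairs; labels map to colors via color_map
        return [("The quick ", None), ("brown", "correct"), (" fox", "wrong")]

    demo = gr.Interface(
        fn=fake_enc,
        inputs=gr.Textbox(),
        outputs=gr.HighlightedText(
            show_legend=True,
            color_map={"correct": "green", "wrong": "red"},
        ),
    )
    # demo.launch()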
main.py CHANGED
@@ -171,7 +171,7 @@ def main(args):
     print(f" Max New Tokens Ratio: {args.max_new_tokens_ratio}")
     print(f" Number of Beams: {args.num_beams}")
     print("=" * os.get_terminal_size().columns)
-    text, msg_rate = generate(
+    text, msg_rate, tokens_info = generate(
         tokenizer=tokenizer,
         model=model,
         prompt=args.prompt,
model_factory.py CHANGED
@@ -70,3 +70,10 @@ class ModelFactory:
     @classmethod
     def get_models_names(cls):
         return list(cls.models_names.keys())
+
+    @classmethod
+    def get_model_max_length(cls, name: str):
+        if name in cls.tokenizers:
+            return cls.tokenizers[name].model_max_length
+        else:
+            return 0
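Note: the new helper returns 0 when the tokenizer for name has not been cached yet (assuming load_model is what populates cls.tokenizers, as the rest of the factory implies), so callers should treat 0 as "unknown" rather than as a real limit:

    max_len = ModelFactory.get_model_max_length("gpt2")
    if max_len == 0:
        ModelFactory.load_model("gpt2")  # populates the tokenizer cache
        max_len = ModelFactory.get_model_max_length("gpt2")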
processors.py CHANGED
@@ -127,6 +127,7 @@ class EncryptorLogitsProcessor(LogitsProcessor, BaseProcessor):
         self.raw_msg = msg
         self.msg = bytes_to_base(msg, self.msg_base)
         self.gamma = gamma
+        self.tokenizer = tokenizer
         special_tokens = [
             tokenizer.bos_token_id,
             tokenizer.eos_token_id,
@@ -169,20 +170,69 @@ class EncryptorLogitsProcessor(LogitsProcessor, BaseProcessor):
     def get_message_len(self):
         return len(self.msg)
 
+    def __map_input_ids(self, input_ids: torch.Tensor, base_arr, byte_arr):
+        byte_enc_msg = [-1 for _ in range(input_ids.size(0))]
+        base_enc_msg = [-1 for _ in range(input_ids.size(0))]
+        base_msg = [-1 for _ in range(input_ids.size(0))]
+        byte_msg = [-1 for _ in range(input_ids.size(0))]
+
+        values_per_byte = get_values_per_byte(self.msg_base)
+        start = self.start_pos % values_per_byte
+
+        for i, b in enumerate(base_arr):
+            base_enc_msg[i] = base_arr[i]
+            byte_enc_msg[i] = byte_arr[(i - start) // values_per_byte]
+
+        for i, b in enumerate(self.msg):
+            base_msg[i + self.start_pos] = b
+            byte_msg[i + self.start_pos] = self.raw_msg[i // values_per_byte]
+
+        return base_msg, byte_msg, base_enc_msg, byte_enc_msg
+
     def validate(self, input_ids_batch: torch.Tensor):
         res = []
+        tokens_infos = []
         for input_ids in input_ids_batch:
-            values = []
-            for i in range(self.start_pos, input_ids.size(0)):
-                values.append(self._get_value(input_ids[: i + 1]))
-            enc_msg = base_to_bytes(values, self.msg_base)
+            # Initialization
+            base_arr = []
+
+            # Loop and obtain values of all tokens
+            for i in range(0, input_ids.size(0)):
+                base_arr.append(self._get_value(input_ids[: i + 1]))
+
+            values_per_byte = get_values_per_byte(self.msg_base)
+
+            # Transform the values to bytes
+            start = self.start_pos % values_per_byte
+            byte_arr = base_to_bytes(base_arr[start:], self.msg_base)
+
+            # Construct the encoded message and compare it against the original
             cnt = 0
-            for i in range(len(self.raw_msg)):
+            enc_msg = byte_arr[self.start_pos // values_per_byte :]
+            for i in range(min(len(enc_msg), len(self.raw_msg))):
                 if self.raw_msg[i] == enc_msg[i]:
                     cnt += 1
             res.append(cnt / len(self.raw_msg))
 
-        return res
+            base_msg, byte_msg, base_enc_msg, byte_enc_msg = (
+                self.__map_input_ids(input_ids, base_arr, byte_arr)
+            )
+            tokens = []
+            input_strs = [self.tokenizer.decode([input]) for input in input_ids]
+            for i in range(len(base_enc_msg)):
+                tokens.append(
+                    {
+                        "token": input_strs[i],
+                        "base_enc": base_enc_msg[i],
+                        "byte_enc": byte_enc_msg[i],
+                        "base_msg": base_msg[i],
+                        "byte_msg": byte_msg[i],
+                        "byte_id": (i - start) // values_per_byte,
+                    }
+                )
+            tokens_infos.append(tokens)
+
+        return res, tokens_infos
 
 
 class DecryptorProcessor(BaseProcessor):
@@ -199,7 +249,7 @@ class DecryptorProcessor(BaseProcessor):
         bytes_msg = []
         for i, input_ids in enumerate(input_ids_batch):
             msg.append(list())
-            for j in range(self.window_length + shift, len(input_ids)):
+            for j in range(shift, len(input_ids)):
                 # TODO: this could be slow. Consider reimplementing this.
                 value = self._get_value(input_ids[: j + 1])
                 msg[i].append(value)
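Note: the mapping logic leans on get_values_per_byte, bytes_to_base, and base_to_bytes, which are defined elsewhere in the repo. A sketch of one plausible reading of their contract (an assumption for illustration, not the repo's actual implementation): each byte is expanded into as many base-`base` digits as are needed to cover 256 values, and the inverse regroups digits into bytes.

    def get_values_per_byte(base: int) -> int:
        # number of base-`base` digits needed to represent one byte (256 values)
        n, count = 1, 0
        while n < 256:
            n *= base
            count += 1
        return count

    def bytes_to_base(data: bytes, base: int) -> list:
        # big-endian digit expansion of each byte
        per_byte = get_values_per_byte(base)
        digits = []
        for b in data:
            for k in reversed(range(per_byte)):
                digits.append((b // base**k) % base)
        return digits

    def base_to_bytes(digits: list, base: int) -> bytes:
        # inverse: regroup digits into bytes, dropping a trailing partial group
        per_byte = get_values_per_byte(base)
        out = []
        for i in range(0, len(digits) - per_byte + 1, per_byte):
            value = 0
            for d in digits[i : i + per_byte]:
                value = value * base + d
            out.append(value)
        return bytes(out)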
schemes.py CHANGED
@@ -20,6 +20,7 @@ class EncryptionBody(BaseModel):
         "encrypt.default", "max_new_tokens_ratio"
     )
     num_beams: int = GlobalConfig.get("encrypt.default", "num_beams")
+    repetition_penalty: float = GlobalConfig.get("encrypt.default", "repetition_penalty")
 
 class DecryptionBody(BaseModel):
     text: str
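Note: these pydantic field defaults are evaluated once, when the class is defined, so edits to config.ini only take effect after a restart. A quick check (field names other than repetition_penalty are assumptions; EncryptionBody's full field list is not shown in this diff):

    body = EncryptionBody(prompt="hello", msg="hi")  # assumed required fields
    print(body.repetition_penalty)  # 1.0, from [encrypt.default] in config.ini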
stegno.py CHANGED
@@ -20,6 +20,7 @@ def generate(
     private_key: Union[int, None] = None,
     max_new_tokens_ratio: float = 2,
     num_beams: int = 4,
+    repetition_penalty: float = 1.0,
 ):
     """
     Generate the sequence containing the hidden data.
@@ -61,17 +62,20 @@ def generate(
         salt_key=salt_key,
         private_key=private_key,
     )
-    min_length = start_pos + logits_processor.get_message_len()
-    max_length = int(
+    min_length = prompt_size + start_pos + logits_processor.get_message_len()
+    max_length = prompt_size + int(
         start_pos + logits_processor.get_message_len() * max_new_tokens_ratio
     )
+    max_length = min(max_length, tokenizer.model_max_length)
+    min_length = min(min_length, max_length)
     output_tokens = model.generate(
         **tokenized_input,
         logits_processor=transformers.LogitsProcessorList([logits_processor]),
-        min_new_tokens=min_length,
-        max_new_tokens=max_length,
+        min_length=min_length,
+        max_length=max_length,
         do_sample=True,
-        num_beams=num_beams
+        num_beams=num_beams,
+        repetition_penalty=float(repetition_penalty),
     )
 
     output_tokens = output_tokens[:, prompt_size:]
@@ -81,9 +85,9 @@ def generate(
     output_tokens_post = tokenizer(output_text, return_tensors="pt").to(
         model.device
     )
-    msg_rates = logits_processor.validate(output_tokens_post.input_ids)
+    msg_rates, tokens_infos = logits_processor.validate(output_tokens_post.input_ids)
 
-    return output_text, msg_rates[0]
+    return output_text, msg_rates[0], tokens_infos[0]
 
 
 def decrypt(
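Note: this hunk switches from min_new_tokens/max_new_tokens to min_length/max_length, which count the prompt tokens as well; that is why prompt_size is now added to both bounds and the result is clamped to tokenizer.model_max_length. The arithmetic, with illustrative numbers only:

    prompt_size = 12            # tokens in the prompt
    start_pos = 3               # tokens generated before encoding starts
    msg_len = 40                # logits_processor.get_message_len()
    max_new_tokens_ratio = 2.0
    model_max_length = 1024     # e.g. GPT-2's context window

    min_length = prompt_size + start_pos + msg_len                               # 55
    max_length = prompt_size + int(start_pos + msg_len * max_new_tokens_ratio)   # 95
    max_length = min(max_length, model_max_length)   # stay inside the context window
    min_length = min(min_length, max_length)         # never let min exceed max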