Mahesh2841 commited on
Commit
7c0882c
·
verified ·
1 Parent(s): c23327e

Update custom_modeling.py

Browse files
Files changed (1) hide show
  1. custom_modeling.py +45 -12
custom_modeling.py CHANGED
@@ -1,8 +1,9 @@
1
  """
2
- custom_modeling.py – model-agnostic toxicity wrapper
3
- ----------------------------------------------------
4
  Place in repo root together with:
5
  • toxic.keras
 
6
  Add to config.json:
7
  "auto_map": { "AutoModelForCausalLM": "custom_modeling.SafeGenerationModel" }
8
  """
@@ -17,15 +18,19 @@ from huggingface_hub import hf_hub_download
17
 
18
 
19
  # ------------------------------------------------------------------ #
20
- # 1) MIXIN – toxicity filtering logic #
21
  # ------------------------------------------------------------------ #
22
  class _SafeGenerationMixin:
23
  _toxicity_model = None
24
- _tox_threshold = 0.6
 
 
25
 
26
- # Separate messages
27
- _safe_in_msg = "Sorry, I can’t help with that request."
28
- _safe_out_msg = "I’m sorry, but I can’t continue with that."
 
 
29
 
30
  _tokenizer = None
31
 
@@ -43,6 +48,16 @@ class _SafeGenerationMixin:
43
  self._toxicity_model = tf.keras.models.load_model(path, compile=False)
44
  return self._toxicity_model
45
 
 
 
 
 
 
 
 
 
 
 
46
  def _ensure_tokenizer(self):
47
  if self._tokenizer is None:
48
  try:
@@ -56,9 +71,16 @@ class _SafeGenerationMixin:
56
  if not text.strip():
57
  return False
58
  inputs = tf.constant([text], dtype=tf.string)
59
- prob = float(self._tox_model.predict(inputs)[0, 0])
60
  return prob >= self._tox_threshold
61
 
 
 
 
 
 
 
 
62
  def _safe_ids(self, message: str, length: int | None = None):
63
  """Encode *message* and pad/truncate to *length* tokens (if given)."""
64
  self._ensure_tokenizer()
@@ -84,7 +106,7 @@ class _SafeGenerationMixin:
84
  def generate(self, *args, **kwargs):
85
  self._ensure_tokenizer()
86
 
87
- # 1) prompt toxicity
88
  prompt_txt = None
89
  if self._tokenizer is not None:
90
  if "input_ids" in kwargs:
@@ -96,23 +118,34 @@ class _SafeGenerationMixin:
96
  args[0][0].tolist(), skip_special_tokens=True
97
  )
98
 
 
 
 
 
 
99
  if prompt_txt and self._is_toxic(prompt_txt):
100
  return self._safe_ids(self._safe_in_msg).unsqueeze(0)
101
 
102
- # 2) normal generation
103
  outputs = super().generate(*args, **kwargs)
104
 
105
- # 3) output toxicity
106
  if self._tokenizer is None:
107
  return outputs
108
 
109
  new_seqs = []
110
  for seq in outputs.detach().cpu():
111
  txt = self._tokenizer.decode(seq.tolist(), skip_special_tokens=True)
112
- if self._is_toxic(txt):
 
 
 
 
 
113
  new_seqs.append(self._safe_ids(self._safe_out_msg, length=seq.size(0)))
114
  else:
115
  new_seqs.append(seq)
 
116
  return torch.stack(new_seqs, dim=0).to(self._device())
117
 
118
 
 
1
  """
2
+ custom_modeling.py – model-agnostic toxicity and prompt injection wrapper
3
+ --------------------------------------------------------------------------
4
  Place in repo root together with:
5
  • toxic.keras
6
+ • PI.keras
7
  Add to config.json:
8
  "auto_map": { "AutoModelForCausalLM": "custom_modeling.SafeGenerationModel" }
9
  """
 
18
 
19
 
20
  # ------------------------------------------------------------------ #
21
+ # 1) MIXIN – toxicity and prompt injection filtering logic #
22
  # ------------------------------------------------------------------ #
23
  class _SafeGenerationMixin:
24
  _toxicity_model = None
25
+ _pi_model = None
26
+ _tox_threshold = 0.6
27
+ _pi_threshold = 0.9
28
 
29
+ # Safety messages
30
+ _safe_in_msg = "Sorry, I can't help with that request."
31
+ _safe_out_msg = "I'm sorry, but I can't continue with that."
32
+ _pi_in_msg = "PI detected at Input level"
33
+ _pi_out_msg = "PI detected at output level"
34
 
35
  _tokenizer = None
36
 
 
48
  self._toxicity_model = tf.keras.models.load_model(path, compile=False)
49
  return self._toxicity_model
50
 
51
+ @property
52
+ def _prompt_injection_model(self):
53
+ if self._pi_model is None:
54
+ path = hf_hub_download(
55
+ repo_id=self.config.name_or_path,
56
+ filename="PI.keras",
57
+ )
58
+ self._pi_model = tf.keras.models.load_model(path, compile=False)
59
+ return self._pi_model
60
+
61
  def _ensure_tokenizer(self):
62
  if self._tokenizer is None:
63
  try:
 
71
  if not text.strip():
72
  return False
73
  inputs = tf.constant([text], dtype=tf.string)
74
+ prob = float(self._tox_model.predict(inputs)[0, 0])
75
  return prob >= self._tox_threshold
76
 
77
+ def _has_prompt_injection(self, text: str) -> bool:
78
+ if not text.strip():
79
+ return False
80
+ inputs = tf.constant([text], dtype=tf.string)
81
+ prob = float(self._prompt_injection_model.predict(inputs)[0, 0])
82
+ return prob >= self._pi_threshold
83
+
84
  def _safe_ids(self, message: str, length: int | None = None):
85
  """Encode *message* and pad/truncate to *length* tokens (if given)."""
86
  self._ensure_tokenizer()
 
106
  def generate(self, *args, **kwargs):
107
  self._ensure_tokenizer()
108
 
109
+ # 1) Extract prompt text
110
  prompt_txt = None
111
  if self._tokenizer is not None:
112
  if "input_ids" in kwargs:
 
118
  args[0][0].tolist(), skip_special_tokens=True
119
  )
120
 
121
+ # 2) Check input for prompt injection (higher priority)
122
+ if prompt_txt and self._has_prompt_injection(prompt_txt):
123
+ return self._safe_ids(self._pi_in_msg).unsqueeze(0)
124
+
125
+ # 3) Check input for toxicity
126
  if prompt_txt and self._is_toxic(prompt_txt):
127
  return self._safe_ids(self._safe_in_msg).unsqueeze(0)
128
 
129
+ # 4) Normal generation
130
  outputs = super().generate(*args, **kwargs)
131
 
132
+ # 5) Check outputs for safety violations
133
  if self._tokenizer is None:
134
  return outputs
135
 
136
  new_seqs = []
137
  for seq in outputs.detach().cpu():
138
  txt = self._tokenizer.decode(seq.tolist(), skip_special_tokens=True)
139
+
140
+ # Check for prompt injection first (higher priority)
141
+ if self._has_prompt_injection(txt):
142
+ new_seqs.append(self._safe_ids(self._pi_out_msg, length=seq.size(0)))
143
+ # Then check for toxicity
144
+ elif self._is_toxic(txt):
145
  new_seqs.append(self._safe_ids(self._safe_out_msg, length=seq.size(0)))
146
  else:
147
  new_seqs.append(seq)
148
+
149
  return torch.stack(new_seqs, dim=0).to(self._device())
150
 
151