File size: 7,388 Bytes
aa809cf 59afed2 aa809cf 59afed2 aa809cf 59afed2 aa809cf 59afed2 aa809cf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 |
"""
custom_modeling.py β model-agnostic toxicity and prompt injection wrapper
--------------------------------------------------------------------------
Place in repo root together with:
β’ toxic.keras
β’ PI.keras
Add to config.json:
"auto_map": { "AutoModelForCausalLM": "custom_modeling.SafeGenerationModel" }
"""
import importlib
from functools import lru_cache
import torch
import transformers
import tensorflow as tf
from huggingface_hub import hf_hub_download
# ------------------------------------------------------------------ #
# 1) MIXIN β toxicity and prompt injection filtering logic #
# ------------------------------------------------------------------ #
class _SafeGenerationMixin:
_toxicity_model = None
_pi_model = None
_tox_threshold = 0.6
_pi_threshold = 0.9
# Safety messages
_safe_in_msg = "Sorry, I can't help with that request."
_safe_out_msg = "I'm sorry, but I can't continue with that."
_pi_in_msg = "PI detected at Input level"
_pi_out_msg = "PI detected at output level"
_tokenizer = None
# ---- helpers ----------------------------------------------------
def _device(self):
return next(self.parameters()).device
@property
def _tox_model(self):
if self._toxicity_model is None:
path = hf_hub_download(
repo_id=self.config.name_or_path,
filename="toxic.keras",
)
# β
Force CPU loading for better performance
with tf.device('/CPU:0'):
self._toxicity_model = tf.keras.models.load_model(path, compile=False)
return self._toxicity_model
@property
def _prompt_injection_model(self):
if self._pi_model is None:
path = hf_hub_download(
repo_id=self.config.name_or_path,
filename="PI.keras",
)
# β
Force CPU loading for better performance
with tf.device('/CPU:0'):
self._pi_model = tf.keras.models.load_model(path, compile=False)
return self._pi_model
def _ensure_tokenizer(self):
if self._tokenizer is None:
try:
self._tokenizer = transformers.AutoTokenizer.from_pretrained(
self.config.name_or_path, trust_remote_code=True
)
except Exception:
pass
def _is_toxic(self, text: str) -> bool:
if not text.strip():
return False
# β
Ensure CPU execution
with tf.device('/CPU:0'):
inputs = tf.constant([text], dtype=tf.string)
prob = float(self._tox_model.predict(inputs, verbose=0)[0, 0])
return prob >= self._tox_threshold
def _has_prompt_injection(self, text: str) -> bool:
if not text.strip():
return False
# β
Ensure CPU execution
with tf.device('/CPU:0'):
inputs = tf.constant([text], dtype=tf.string)
prob = float(self._prompt_injection_model.predict(inputs, verbose=0)[0, 0])
return prob >= self._pi_threshold
def _safe_ids(self, message: str, length: int | None = None):
"""Encode *message* and pad/truncate to *length* tokens (if given)."""
self._ensure_tokenizer()
if self._tokenizer is None:
raise RuntimeError("Tokenizer unavailable for safe-message encoding.")
ids = self._tokenizer(message, return_tensors="pt")["input_ids"][0]
if length is not None:
pad_id = (
self.config.eos_token_id
if self.config.eos_token_id is not None
else (self.config.pad_token_id or 0)
)
if ids.size(0) < length:
ids = torch.cat(
[ids, ids.new_full((length - ids.size(0),), pad_id)], dim=0
)
else:
ids = ids[:length]
return ids.to(self._device())
# ---- main override ---------------------------------------------
def generate(self, *args, **kwargs):
self._ensure_tokenizer()
# 1) Extract prompt text
prompt_txt = None
if self._tokenizer is not None:
if "input_ids" in kwargs:
prompt_txt = self._tokenizer.decode(
kwargs["input_ids"][0].tolist(), skip_special_tokens=True
)
elif args:
prompt_txt = self._tokenizer.decode(
args[0][0].tolist(), skip_special_tokens=True
)
# 2) Check input for prompt injection (higher priority)
if prompt_txt and self._has_prompt_injection(prompt_txt):
return self._safe_ids(self._pi_in_msg).unsqueeze(0)
# 3) Check input for toxicity
if prompt_txt and self._is_toxic(prompt_txt):
return self._safe_ids(self._safe_in_msg).unsqueeze(0)
# 4) Normal generation
outputs = super().generate(*args, **kwargs)
# 5) Check outputs for safety violations
if self._tokenizer is None:
return outputs
new_seqs = []
for seq in outputs.detach().cpu():
txt = self._tokenizer.decode(seq.tolist(), skip_special_tokens=True)
# Check for prompt injection first (higher priority)
if self._has_prompt_injection(txt):
new_seqs.append(self._safe_ids(self._pi_out_msg, length=seq.size(0)))
# Then check for toxicity
elif self._is_toxic(txt):
new_seqs.append(self._safe_ids(self._safe_out_msg, length=seq.size(0)))
else:
new_seqs.append(seq)
return torch.stack(new_seqs, dim=0).to(self._device())
# ------------------------------------------------------------------ #
# 2) utilities: resolve base class & cache subclass #
# ------------------------------------------------------------------ #
@lru_cache(None)
def _get_base_cls(arch: str):
if hasattr(transformers, arch):
return getattr(transformers, arch)
stem = arch.replace("ForCausalLM", "").lower()
module = importlib.import_module(f"transformers.models.{stem}.modeling_{stem}")
return getattr(module, arch)
@lru_cache(None)
def _make_safe_subclass(base_cls):
return type(
f"SafeGeneration_{base_cls.__name__}",
(_SafeGenerationMixin, base_cls),
{},
)
# ------------------------------------------------------------------ #
# 3) Dispatcher class β referenced by auto_map #
# ------------------------------------------------------------------ #
class SafeGenerationModel:
@classmethod
def from_pretrained(cls, repo_id, *model_args, **kwargs):
kwargs.setdefault("trust_remote_code", True)
if kwargs.get("torch_dtype") == "auto":
kwargs.pop("torch_dtype")
config = transformers.AutoConfig.from_pretrained(repo_id, **kwargs)
if not getattr(config, "architectures", None):
raise ValueError("`config.architectures` missing in config.json.")
arch_str = config.architectures[0]
Base = _get_base_cls(arch_str)
Safe = _make_safe_subclass(Base)
kwargs.pop("config", None) # avoid duplicate
return Safe.from_pretrained(repo_id, *model_args, config=config, **kwargs) |