"""
custom_modeling.py  – model-agnostic toxicity and prompt injection wrapper
--------------------------------------------------------------------------
Place in repo root together with:
  • toxic.keras
  • PI.keras
Add to config.json:
  "auto_map": { "AutoModelForCausalLM": "custom_modeling.SafeGenerationModel" }
"""

import importlib
from functools import lru_cache

import torch
import transformers
import tensorflow as tf
from huggingface_hub import hf_hub_download


# ------------------------------------------------------------------ #
# 1)  MIXIN – toxicity and prompt injection filtering logic          #
# ------------------------------------------------------------------ #
class _SafeGenerationMixin:
    _toxicity_model = None
    _pi_model = None
    _tox_threshold = 0.6
    _pi_threshold = 0.9

    # Safety messages
    _safe_in_msg = "Sorry, I can't help with that request."
    _safe_out_msg = "I'm sorry, but I can't continue with that."
    _pi_in_msg = "PI detected at input level"
    _pi_out_msg = "PI detected at output level"

    _tokenizer = None

    # ---- helpers ----------------------------------------------------
    def _device(self):
        return next(self.parameters()).device

    @property
    def _tox_model(self):
        if self._toxicity_model is None:
            path = hf_hub_download(
                repo_id=self.config.name_or_path,
                filename="toxic.keras",
            )
            self._toxicity_model = tf.keras.models.load_model(path, compile=False)
        return self._toxicity_model

    @property
    def _prompt_injection_model(self):
        if self._pi_model is None:
            path = hf_hub_download(
                repo_id=self.config.name_or_path,
                filename="PI.keras",
            )
            self._pi_model = tf.keras.models.load_model(path, compile=False)
        return self._pi_model

    def _ensure_tokenizer(self):
        if self._tokenizer is None:
            try:
                self._tokenizer = transformers.AutoTokenizer.from_pretrained(
                    self.config.name_or_path, trust_remote_code=True
                )
            except Exception:
                pass

    def _is_toxic(self, text: str) -> bool:
        if not text.strip():
            return False
        inputs = tf.constant([text], dtype=tf.string)
        prob = float(self._tox_model.predict(inputs)[0, 0])
        return prob >= self._tox_threshold

    def _has_prompt_injection(self, text: str) -> bool:
        if not text.strip():
            return False
        inputs = tf.constant([text], dtype=tf.string)
        prob = float(self._prompt_injection_model.predict(inputs)[0, 0])
        return prob >= self._pi_threshold

    def _safe_ids(self, message: str, length: int | None = None):
        """Encode *message* and pad/truncate to *length* tokens (if given)."""
        self._ensure_tokenizer()
        if self._tokenizer is None:
            raise RuntimeError("Tokenizer unavailable for safe-message encoding.")

        ids = self._tokenizer(message, return_tensors="pt")["input_ids"][0]
        if length is not None:
            pad_id = (
                self.config.eos_token_id
                if self.config.eos_token_id is not None
                else (self.config.pad_token_id or 0)
            )
            if ids.size(0) < length:
                ids = torch.cat(
                    [ids, ids.new_full((length - ids.size(0),), pad_id)], dim=0
                )
            else:
                ids = ids[:length]
        return ids.to(self._device())

    # ---- main override ---------------------------------------------
    def generate(self, *args, **kwargs):
        self._ensure_tokenizer()

        # 1) Extract prompt text
        prompt_txt = None
        if self._tokenizer is not None:
            if "input_ids" in kwargs:
                prompt_txt = self._tokenizer.decode(
                    kwargs["input_ids"][0].tolist(), skip_special_tokens=True
                )
            elif args:
                prompt_txt = self._tokenizer.decode(
                    args[0][0].tolist(), skip_special_tokens=True
                )

        # 2) Check input for prompt injection (higher priority)
        if prompt_txt and self._has_prompt_injection(prompt_txt):
            return self._safe_ids(self._pi_in_msg).unsqueeze(0)

        # 3) Check input for toxicity
        if prompt_txt and self._is_toxic(prompt_txt):
            return self._safe_ids(self._safe_in_msg).unsqueeze(0)

        # 4) Normal generation
        outputs = super().generate(*args, **kwargs)

        # 5) Check outputs for safety violations
        if self._tokenizer is None or not isinstance(outputs, torch.Tensor):
            # Without a tokenizer, or when `generate` returned a structured
            # GenerateOutput (return_dict_in_generate=True), skip re-checking.
            return outputs

        new_seqs = []
        for seq in outputs.detach().cpu():
            txt = self._tokenizer.decode(seq.tolist(), skip_special_tokens=True)
            
            # Check for prompt injection first (higher priority)
            if self._has_prompt_injection(txt):
                new_seqs.append(self._safe_ids(self._pi_out_msg, length=seq.size(0)))
            # Then check for toxicity
            elif self._is_toxic(txt):
                new_seqs.append(self._safe_ids(self._safe_out_msg, length=seq.size(0)))
            else:
                new_seqs.append(seq)
        
        return torch.stack(new_seqs, dim=0).to(self._device())


# ------------------------------------------------------------------ #
# 2)  utilities: resolve base class & cache subclass                 #
# ------------------------------------------------------------------ #
@lru_cache(None)
def _get_base_cls(arch: str):
    """Resolve an architecture name such as "LlamaForCausalLM" to its class."""
    if hasattr(transformers, arch):
        return getattr(transformers, arch)
    # Best-effort fallback: derive the module path from the architecture name.
    # Works for single-word families ("llama", "mistral"); architectures whose
    # module name contains underscores (e.g. "gpt_neox") must be importable
    # directly from the top-level `transformers` namespace above.
    stem = arch.replace("ForCausalLM", "").lower()
    module = importlib.import_module(f"transformers.models.{stem}.modeling_{stem}")
    return getattr(module, arch)


@lru_cache(None)
def _make_safe_subclass(base_cls):
    return type(
        f"SafeGeneration_{base_cls.__name__}",
        (_SafeGenerationMixin, base_cls),
        {},
    )


# ------------------------------------------------------------------ #
# 3)  Dispatcher class – referenced by auto_map                      #
# ------------------------------------------------------------------ #
class SafeGenerationModel:
    @classmethod
    def from_pretrained(cls, repo_id, *model_args, **kwargs):
        kwargs.setdefault("trust_remote_code", True)
        if kwargs.get("torch_dtype") == "auto":
            kwargs.pop("torch_dtype")

        config = transformers.AutoConfig.from_pretrained(repo_id, **kwargs)
        if not getattr(config, "architectures", None):
            raise ValueError("`config.architectures` missing in config.json.")
        arch_str = config.architectures[0]

        Base = _get_base_cls(arch_str)
        Safe = _make_safe_subclass(Base)

        kwargs.pop("config", None)    # avoid duplicate
        return Safe.from_pretrained(repo_id, *model_args, config=config, **kwargs)
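

# ------------------------------------------------------------------ #
# 4)  Example usage (illustrative sketch)                            #
# ------------------------------------------------------------------ #
# "your-org/your-safe-model" is a placeholder for a Hub repo whose
# config.json maps AutoModelForCausalLM to SafeGenerationModel and which
# ships toxic.keras / PI.keras next to this file:
#
#   import transformers
#
#   tok = transformers.AutoTokenizer.from_pretrained("your-org/your-safe-model")
#   model = transformers.AutoModelForCausalLM.from_pretrained(
#       "your-org/your-safe-model", trust_remote_code=True
#   )
#   inputs = tok("Tell me a short story about a friendly robot.", return_tensors="pt")
#   out = model.generate(**inputs, max_new_tokens=64)
#   print(tok.decode(out[0], skip_special_tokens=True))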