Mahesh2841 committed (verified)
Commit becf062 · Parent: 4e28148

Update custom_modeling.py

Files changed (1)
  1. custom_modeling.py +79 -73
custom_modeling.py CHANGED
@@ -1,19 +1,18 @@
 """
-SafeGenerationModel
--------------------
-Runtime-agnostic toxicity wrapper for ANY causal-LM on Hugging Face.
+custom_modeling.py
+------------------
+Model-agnostic toxicity wrapper for any Hugging Face causal-LM.

-Add in `config.json`:
-    "auto_map": {"AutoModelForCausalLM": "custom_modeling.SafeGenerationModel"}
+Add (or keep) in your config.json:
+    "auto_map": {
+        "AutoModelForCausalLM": "custom_modeling.SafeGenerationModel"
+    }

-Requires:
-  • toxic.keras    a Keras model that outputs sigmoid-probability for "toxic"
-  • transformers   – >= 4.38
-  • tensorflow     – for the classifier
+Files that must live in the repo alongside this script:
+  • toxic.keras  – Keras classifier (sigmoid output: toxic prob)
 """

 import importlib
-from types import MethodType
 from functools import lru_cache

 import torch
@@ -23,34 +22,33 @@ from huggingface_hub import hf_hub_download


 # ---------------------------------------------------------------------
-# 1) MIXIN –- all toxicity logic lives here
+# 1) MIXIN  all toxicity logic lives here
 # ---------------------------------------------------------------------
 class _SafeGenerationMixin:
-    """
-    Mixin that overrides `generate()` to filter toxic prompts / outputs.
-    Must appear *before* the real base LM class in the MRO.
-    """
+    """Mixin that overrides .generate() to filter toxic prompts / outputs."""

-    _toxicity_model = None  # lazy-loaded TF model
-    _tox_threshold = 0.6    # edit if needed
+    _toxicity_model = None
+    _tox_threshold = 0.6
     _safe_message = (
         "Response is toxic, please be kind to yourself and others."
     )
     _tokenizer = None

-    # -------------------- utilities --------------------
+    # ----- helper: load classifier on first use -----------------------
     @property
     def _tox_model(self):
-        """Load the `.keras` model the first time we need it."""
         if self._toxicity_model is None:
             path = hf_hub_download(
                 repo_id=self.config.name_or_path,
                 filename="toxic.keras",
             )
-            self._toxicity_model = tf.keras.models.load_model(path, compile=False)
+            self._toxicity_model = tf.keras.models.load_model(
+                path, compile=False
+            )
         return self._toxicity_model

-    def _load_tokenizer(self):
+    # ----- helper: load tokenizer (once) ------------------------------
+    def _ensure_tokenizer(self):
         if self._tokenizer is None:
             try:
                 self._tokenizer = transformers.AutoTokenizer.from_pretrained(
@@ -59,15 +57,16 @@ class _SafeGenerationMixin:
             except Exception:
                 pass

+    # ----- helper: tox check -----------------------------------------
     def _is_toxic(self, text: str) -> bool:
         if not text.strip():
             return False
         prob = float(self._tox_model.predict([text])[0, 0])
         return prob >= self._tox_threshold

+    # ----- helper: safe token ids ------------------------------------
     def _safe_ids(self, length: int | None = None) -> torch.LongTensor:
-        """Return token IDs for the safe-message, padded / truncated to *length*."""
-        self._load_tokenizer()
+        self._ensure_tokenizer()
         if self._tokenizer is None:
             raise RuntimeError("Tokenizer unavailable for safe-message encoding.")
         ids = self._tokenizer(self._safe_message, return_tensors="pt")["input_ids"][0]
@@ -86,92 +85,99 @@ class _SafeGenerationMixin:
         ids = ids[:length]
         return ids.to(self.device)

-    # -------------------- override generate --------------------
+    # ----- override generate() ---------------------------------------
     def generate(self, *args, **kwargs):
-        self._load_tokenizer()
+        self._ensure_tokenizer()

-        # 1) Decode prompt toxicity check
-        prompt_text = None
+        # 1) prompt toxicity
+        prompt_txt = None
         if self._tokenizer is not None:
             if "input_ids" in kwargs:
-                prompt_text = self._tokenizer.decode(
+                prompt_txt = self._tokenizer.decode(
                     kwargs["input_ids"][0].tolist(), skip_special_tokens=True
                 )
             elif args:
-                prompt_text = self._tokenizer.decode(
+                prompt_txt = self._tokenizer.decode(
                     args[0][0].tolist(), skip_special_tokens=True
                 )

-        if prompt_text and self._is_toxic(prompt_text):
+        if prompt_txt and self._is_toxic(prompt_txt):
             return self._safe_ids().unsqueeze(0)

-        # 2) Normal generation (super() == real LM class)
-        outputs = super().generate(*args, **kwargs)
+        # 2) normal generation
+        output = super().generate(*args, **kwargs)

-        # 3) Toxicity check on completions
+        # 3) output toxicity
         if self._tokenizer is None:
-            return outputs  # cannot decode → skip
-
-        outs_cpu = outputs.detach().cpu()
-        safe_batches = []
-        for seq in outs_cpu:
+            return output
+        seqs = output.detach().cpu()
+        safe = []
+        for seq in seqs:
             txt = self._tokenizer.decode(seq.tolist(), skip_special_tokens=True)
             if self._is_toxic(txt):
-                safe_batches.append(self._safe_ids(length=seq.size(0)))
+                safe.append(self._safe_ids(length=seq.size(0)))
             else:
-                safe_batches.append(seq)
-        return torch.stack(safe_batches, dim=0).to(self.device)
+                safe.append(seq)
+        return torch.stack(safe, dim=0).to(self.device)


 # ---------------------------------------------------------------------
-# 2) Helper: find the REAL base class for this config
+# 2) Resolve base class for the repo’s architecture string
 # ---------------------------------------------------------------------
 @lru_cache(None)
 def _get_base_cls(arch_name: str):
-    """
-    Map 'LlamaForCausalLM' → transformers.LlamaForCausalLM (etc.).
-    Tries top-level attr first, then imports module heuristically.
-    """
+    # direct attribute
     if hasattr(transformers, arch_name):
         return getattr(transformers, arch_name)

-    # Fallback: derive submodule from pattern `xxxForCausalLM`
+    # heuristic import: e.g. LlamaForCausalLM -> transformers.models.llama.modeling_llama
     stem = arch_name.replace("ForCausalLM", "").lower()
-    module_try = f"transformers.models.{stem}.modeling_{stem}"
+    module_path = f"transformers.models.{stem}.modeling_{stem}"
     try:
-        mod = importlib.import_module(module_try)
+        mod = importlib.import_module(module_path)
         return getattr(mod, arch_name)
     except Exception as e:
-        raise ValueError(
-            f"[SafeGeneration] Could not resolve base class for '{arch_name}': {e}"
-        ) from e
+        raise ValueError(f"Cannot resolve base class for '{arch_name}': {e}") from e
+
+
+@lru_cache(None)
+def _make_dynamic_cls(base_cls):
+    """Create (and cache) SafeGeneration_<Base> = (Mixin, Base)."""
+    return type(f"SafeGeneration_{base_cls.__name__}", (_SafeGenerationMixin, base_cls), {})


 # ---------------------------------------------------------------------
-# 3) Dispatcher class – what HF actually instantiates
+# 3) Dispatcher class – target in `auto_map`
 # ---------------------------------------------------------------------
 class SafeGenerationModel:
     """
-    Factory / thin wrapper.  HF instantiates *this*, passing `config`.
-    We inspect `config.architectures[0]`, build a
-    (SafeMixin, RealBaseClass) dynamic subclass, and return an instance.
-    """
-
-    def __new__(cls, config, *args, **kwargs):
-        if not getattr(config, "architectures", None):
-            raise ValueError("`config.architectures` missing – cannot wrap model.")
-
-        base_cls = _get_base_cls(config.architectures[0])
-
-        # Build dynamic subclass only once per *base_cls* (memoised by lru_cache + closure)
-        DynamicSafeCls = _make_dynamic_cls(base_cls)
+    Thin dispatcher used by Hugging Face AutoClass.

-        # Finally create and return the actual model instance
-        return DynamicSafeCls(config, *args, **kwargs)
+    It implements only `from_pretrained()`: determine the true base
+    architecture, build the dynamic subclass, and defer loading to it.
+    """

-# -- internal cache to avoid re-creating identical classes ----------------
-@lru_cache(None)
-def _make_dynamic_cls(base_cls):
-    name = f"SafeGeneration_{base_cls.__name__}"
-    return type(name, (_SafeGenerationMixin, base_cls), {})
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        # propagate trust_remote_code if caller set it
+        kwargs.setdefault("trust_remote_code", True)

+        # 1) load config to know arch string
+        config = transformers.AutoConfig.from_pretrained(
+            pretrained_model_name_or_path, **kwargs
+        )
+        if not getattr(config, "architectures", None):
+            raise ValueError("`config.architectures` missing; cannot wrap model.")
+        arch_name = config.architectures[0]
+
+        # 2) build / retrieve dynamic subclass
+        base_cls = _get_base_cls(arch_name)
+        SafeCls = _make_dynamic_cls(base_cls)
+
+        # 3) delegate full loading
+        return SafeCls.from_pretrained(
+            pretrained_model_name_or_path,
+            *model_args,
+            config=config,
+            **kwargs,
+        )
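
For context, a minimal usage sketch of a repo carrying this commit: it assumes config.json keeps the `auto_map` entry shown above and that toxic.keras is present next to custom_modeling.py; the repo id below is a placeholder, not a real model.

# Hypothetical usage sketch (placeholder repo id; assumes the auto_map entry
# points AutoModelForCausalLM at custom_modeling.SafeGenerationModel).
import transformers

repo_id = "your-user/your-safeguarded-model"  # placeholder

tok = transformers.AutoTokenizer.from_pretrained(repo_id)
model = transformers.AutoModelForCausalLM.from_pretrained(
    repo_id,
    trust_remote_code=True,  # required so the custom SafeGenerationModel is loaded
)

inputs = tok("Hello there!", return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=32)  # toxic prompts/outputs get replaced
print(tok.decode(out[0], skip_special_tokens=True))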