Mahesh2841 committed
Commit b4ead85 · verified · 1 Parent(s): becf062

Update custom_modeling.py

Files changed (1): custom_modeling.py +30 -62
custom_modeling.py CHANGED
@@ -1,15 +1,8 @@
 """
-custom_modeling.py
-------------------
-Model-agnostic toxicity wrapper for any Hugging Face causal-LM.
-
-Add (or keep) in your config.json:
-    "auto_map": {
-        "AutoModelForCausalLM": "custom_modeling.SafeGenerationModel"
-    }
-
-Files that must live in the repo alongside this script:
-    • toxic.keras – Keras classifier (sigmoid output: toxic prob)
+custom_modeling.py – model-agnostic toxicity wrapper
+----------------------------------------------------
+Keep in config.json:
+    "auto_map": {"AutoModelForCausalLM": "custom_modeling.SafeGenerationModel"}
 """
 
 import importlib
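
Because `auto_map` routes `AutoModelForCausalLM` to `custom_modeling.SafeGenerationModel`, a repo that ships this file is loaded through the ordinary Auto API. A minimal usage sketch; the repo id below is a placeholder for whichever model repo carries this script, `toxic.keras`, and the weights:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    repo = "your-user/your-safe-model"   # placeholder repo id
    model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(repo)

    inputs = tokenizer("Hello there!", return_tensors="pt")
    out_ids = model.generate(**inputs, max_new_tokens=32)   # toxicity-filtered generate()
    print(tokenizer.decode(out_ids[0], skip_special_tokens=True))
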
@@ -22,19 +15,15 @@ from huggingface_hub import hf_hub_download
 
 
 # ---------------------------------------------------------------------
-# 1) MIXIN – all toxicity logic lives here
+# 1) MIXIN – toxicity logic
 # ---------------------------------------------------------------------
 class _SafeGenerationMixin:
-    """Mixin that overrides .generate() to filter toxic prompts / outputs."""
-
     _toxicity_model = None
     _tox_threshold = 0.6
-    _safe_message = (
-        "Response is toxic, please be kind to yourself and others."
-    )
+    _safe_message = "Response is toxic, please be kind to yourself and others."
     _tokenizer = None
 
-    # ----- helper: load classifier on first use -----------------------
+    # ---------- classifier ----------
     @property
     def _tox_model(self):
         if self._toxicity_model is None:
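
Because the threshold and the replacement text are plain class attributes read through `self`, they can be tuned on a loaded instance without editing this file. A small sketch, assuming `model` is a loaded SafeGeneration_* instance (e.g. from the usage sketch above):

    model._tox_threshold = 0.8                              # require more classifier confidence
    model._safe_message = "Let's keep things respectful."   # text _safe_ids() presumably encodes
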
@@ -42,12 +31,10 @@ class _SafeGenerationMixin:
                 repo_id=self.config.name_or_path,
                 filename="toxic.keras",
             )
-            self._toxicity_model = tf.keras.models.load_model(
-                path, compile=False
-            )
+            self._toxicity_model = tf.keras.models.load_model(path, compile=False)
         return self._toxicity_model
 
-    # ----- helper: load tokenizer (once) ------------------------------
+    # ---------- tokenizer ----------
     def _ensure_tokenizer(self):
         if self._tokenizer is None:
             try:
@@ -57,14 +44,13 @@ class _SafeGenerationMixin:
             except Exception:
                 pass
 
-    # ----- helper: tox check -----------------------------------------
+    # ---------- helpers ----------
     def _is_toxic(self, text: str) -> bool:
         if not text.strip():
             return False
         prob = float(self._tox_model.predict([text])[0, 0])
         return prob >= self._tox_threshold
 
-    # ----- helper: safe token ids ------------------------------------
     def _safe_ids(self, length: int | None = None) -> torch.LongTensor:
         self._ensure_tokenizer()
         if self._tokenizer is None:
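
The wrapper's only contract with `toxic.keras` is what `_tox_model` and `_is_toxic()` use: `tf.keras.models.load_model(path, compile=False)` must succeed and `predict([text])[0, 0]` must yield a sigmoid toxicity probability for a raw string. A hypothetical classifier satisfying that contract, with an embedding-bag architecture chosen purely for illustration (vocabulary size, sequence length, and training data are placeholders):

    import tensorflow as tf

    vectorize = tf.keras.layers.TextVectorization(max_tokens=20_000, output_sequence_length=128)
    vectorize.adapt(["placeholder corpus, adapt on real training text"])

    clf = tf.keras.Sequential([
        vectorize,                                       # raw string in
        tf.keras.layers.Embedding(20_000, 64),
        tf.keras.layers.GlobalAveragePooling1D(),
        tf.keras.layers.Dense(1, activation="sigmoid"),  # toxic probability out
    ])

    prob = float(clf.predict(["example text"])[0, 0])    # same call shape _is_toxic() relies on
    # ...train on labelled toxic / non-toxic text, then save the artefact the wrapper downloads:
    clf.save("toxic.keras")
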
@@ -77,19 +63,16 @@ class _SafeGenerationMixin:
             else (self.config.pad_token_id or 0)
         )
         if ids.size(0) < length:
-            ids = torch.cat(
-                [ids, ids.new_full((length - ids.size(0),), pad_id)],
-                dim=0,
-            )
+            ids = torch.cat([ids, ids.new_full((length - ids.size(0),), pad_id)], 0)
         else:
             ids = ids[:length]
         return ids.to(self.device)
 
-    # ----- override generate() ---------------------------------------
+    # ---------- override generate ----------
     def generate(self, *args, **kwargs):
         self._ensure_tokenizer()
 
-        # 1) prompt toxicity
+        # 1) prompt check
         prompt_txt = None
         if self._tokenizer is not None:
             if "input_ids" in kwargs:
@@ -100,81 +83,66 @@
                 prompt_txt = self._tokenizer.decode(
                     args[0][0].tolist(), skip_special_tokens=True
                 )
-
         if prompt_txt and self._is_toxic(prompt_txt):
             return self._safe_ids().unsqueeze(0)
 
         # 2) normal generation
-        output = super().generate(*args, **kwargs)
+        out = super().generate(*args, **kwargs)
 
-        # 3) output toxicity
+        # 3) output check
         if self._tokenizer is None:
-            return output
-        seqs = output.detach().cpu()
+            return out
+        seqs = out.detach().cpu()
         safe = []
         for seq in seqs:
-            txt = self._tokenizer.decode(seq.tolist(), skip_special_tokens=True)
-            if self._is_toxic(txt):
+            if self._is_toxic(self._tokenizer.decode(seq.tolist(), skip_special_tokens=True)):
                 safe.append(self._safe_ids(length=seq.size(0)))
             else:
                 safe.append(seq)
-        return torch.stack(safe, dim=0).to(self.device)
+        return torch.stack(safe, 0).to(self.device)
 
 
 # ---------------------------------------------------------------------
-# 2) Resolve base class for the repo’s architecture string
+# 2) helpers – resolve base class & cache dynamic subclass
 # ---------------------------------------------------------------------
 @lru_cache(None)
 def _get_base_cls(arch_name: str):
-    # direct attribute
     if hasattr(transformers, arch_name):
         return getattr(transformers, arch_name)
-
-    # heuristic import: e.g. LlamaForCausalLM -> transformers.models.llama.modeling_llama
     stem = arch_name.replace("ForCausalLM", "").lower()
-    module_path = f"transformers.models.{stem}.modeling_{stem}"
-    try:
-        mod = importlib.import_module(module_path)
-        return getattr(mod, arch_name)
-    except Exception as e:
-        raise ValueError(f"Cannot resolve base class for '{arch_name}': {e}") from e
+    mod = importlib.import_module(f"transformers.models.{stem}.modeling_{stem}")
+    return getattr(mod, arch_name)
 
 
 @lru_cache(None)
 def _make_dynamic_cls(base_cls):
-    """Create (and cache) SafeGeneration_<Base> = (Mixin, Base)."""
     return type(f"SafeGeneration_{base_cls.__name__}", (_SafeGenerationMixin, base_cls), {})
 
 
 # ---------------------------------------------------------------------
-# 3) Dispatcher class – target in `auto_map`
+# 3) dispatcher
 # ---------------------------------------------------------------------
 class SafeGenerationModel:
-    """
-    Thin dispatcher used by Hugging Face AutoClass.
-
-    It implements only `from_pretrained()`: determine the true base
-    architecture, build the dynamic subclass, and defer loading to it.
-    """
-
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-        # propagate trust_remote_code if caller set it
         kwargs.setdefault("trust_remote_code", True)
 
-        # 1) load config to know arch string
+        # 1) load config to know architecture
        config = transformers.AutoConfig.from_pretrained(
             pretrained_model_name_or_path, **kwargs
         )
         if not getattr(config, "architectures", None):
-            raise ValueError("`config.architectures` missing; cannot wrap model.")
+            raise ValueError("`config.architectures` missing – cannot wrap model.")
         arch_name = config.architectures[0]
 
-        # 2) build / retrieve dynamic subclass
+        # 2) dynamic subclass
         base_cls = _get_base_cls(arch_name)
         SafeCls = _make_dynamic_cls(base_cls)
 
-        # 3) delegate full loading
+        # 3) drop duplicate 'config' if caller already passed one
+        kwargs = {k: v for k, v in kwargs.items() if k != "config"}
+
+        # 4) delegate real loading
         return SafeCls.from_pretrained(
             pretrained_model_name_or_path,
             *model_args,
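
To make the cached helpers concrete: for a Llama checkpoint (the same example the removed heuristic-import comment used), `_get_base_cls` resolves the architecture string to the stock transformers class and `_make_dynamic_cls` builds a subclass whose MRO puts the mixin in front of the base, so the overridden `generate()` is the one that runs. A small check, run next to `custom_modeling.py`:

    from custom_modeling import _get_base_cls, _make_dynamic_cls, _SafeGenerationMixin

    Base = _get_base_cls("LlamaForCausalLM")   # found directly as transformers.LlamaForCausalLM
    SafeCls = _make_dynamic_cls(Base)

    print(SafeCls.__name__)     # SafeGeneration_LlamaForCausalLM
    print(SafeCls.__mro__[:3])  # (SafeGeneration_LlamaForCausalLM, _SafeGenerationMixin, LlamaForCausalLM)
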