akaafridi committed on
Commit
9381a8f
·
verified ·
1 Parent(s): 41038fb

Update src/classifier.py

Browse files
Files changed (1) hide show
  1. src/classifier.py +43 -85
src/classifier.py CHANGED
@@ -3,26 +3,13 @@ classifier.py
3
  -------------
4
 
5
  This module defines utilities for classifying the relationship between a
6
- claim and candidate sentences. It leverages a cross-encoder model
7
- pretrained on the Natural Language Inference (NLI) task to assign
8
- labels indicating whether each candidate sentence supports, contradicts,
9
- or is neutral with respect to the claim. When the required
10
- transformers components cannot be loaded (e.g. due to missing
11
- dependencies or lack of network access), the module falls back to a
12
- lightweight heuristic-based classifier.
13
-
14
- The classifier returns one of three string labels for each input pair:
15
-
16
- * ``"support"`` – The sentence entails the claim.
17
- * ``"contradict"`` – The sentence contradicts the claim.
18
- * ``"neutral"`` – The sentence neither supports nor contradicts the claim.
19
-
20
- Example:
21
-
22
- >>> from classifier import classify
23
- >>> labels = classify("The sky is blue", ["The sky is blue on a clear day.", "Grass is green."])
24
- >>> print(labels) # ["support", "neutral"]
25
 
 
 
 
 
26
  """
27
 
28
  from __future__ import annotations
@@ -34,18 +21,15 @@ import numpy as np
34
 
35
  logger = logging.getLogger(__name__)
36
 
37
- _nli_model = None # type: ignore
38
- _nli_tokenizer = None # type: ignore
39
- _use_transformers = False
40
 
41
 
42
  def _load_nli_model(model_name: str = "cross-encoder/nli-roberta-base"):
43
- """Lazy-load the NLI cross-encoder model and tokenizer.
44
-
45
- If loading fails, the fallback heuristic classifier will be used.
46
- """
47
  global _nli_model, _nli_tokenizer, _use_transformers
48
- if _nli_model is not None or _use_transformers:
49
  return
50
  try:
51
  from transformers import AutoTokenizer, AutoModelForSequenceClassification # type: ignore
@@ -56,7 +40,7 @@ def _load_nli_model(model_name: str = "cross-encoder/nli-roberta-base"):
56
  _use_transformers = True
57
  except Exception as exc:
58
  logger.warning(
59
- "Failed to load NLI model '%s'. Falling back to heuristic classifier: %s",
60
  model_name,
61
  exc,
62
  )
@@ -72,98 +56,72 @@ def _classify_with_nli(claim: str, sentences: List[str], batch_size: int = 16) -
72
 
73
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
74
  _nli_model.to(device)
75
- labels_out: List[str] = []
76
 
77
- # Map the model's label indices to human-readable labels.
78
- # The order for 'cross-encoder/nli-roberta-base' is [contradiction, entailment, neutral].
79
  id2label = {0: "contradict", 1: "support", 2: "neutral"}
80
 
81
- # Process in batches to avoid OOM
82
  for start in range(0, len(sentences), batch_size):
83
- batch_sentences = sentences[start:start + batch_size]
84
- encoded = _nli_tokenizer(
85
- [claim] * len(batch_sentences),
86
- batch_sentences,
87
  return_tensors="pt",
88
  truncation=True,
89
  padding=True,
90
  ).to(device)
91
  with torch.no_grad():
92
- outputs = _nli_model(**encoded)
93
- logits = outputs.logits.cpu().numpy()
94
  preds = logits.argmax(axis=1)
95
  labels_out.extend([id2label.get(int(p), "neutral") for p in preds])
96
  return labels_out
97
 
98
 
99
  def _heuristic_classify(claim: str, sentences: List[str]) -> List[str]:
100
- """Simple heuristic classifier used when transformers are unavailable.
101
-
102
- The heuristic checks for lexical overlap between the claim and
103
- candidate sentences and the presence of negation words. It aims to
104
- approximate entailment/contradiction detection without external
105
- dependencies. The rules are very simple and should not be relied on
106
- for production use, but they provide a reasonable fallback.
107
- """
108
  import re
109
 
110
  claim_tokens = set(re.findall(r"\b\w+\b", claim.lower()))
111
- negations = {"not", "no", "never", "none", "cannot", "n't"}
112
- labels: List[str] = []
113
- for sent in sentences:
114
- sent_tokens = set(re.findall(r"\b\w+\b", sent.lower()))
115
- overlap = claim_tokens & sent_tokens
116
- has_neg = any(tok in sent_tokens for tok in negations)
117
  if overlap and not has_neg:
118
- labels.append("support")
119
  elif overlap and has_neg:
120
- labels.append("contradict")
121
  else:
122
- labels.append("neutral")
123
- return labels
124
 
125
 
126
  def classify(claim: str, sentences: Iterable[str], batch_size: int = 16) -> List[str]:
127
- """Classify each sentence in ``sentences`` relative to ``claim``.
128
-
129
- Parameters
130
- ----------
131
- claim:
132
- The claim or hypothesis to compare against.
133
-
134
- sentences:
135
- An iterable of candidate sentences.
136
-
137
- batch_size:
138
- Batch size used when running inference with the transformer model.
139
-
140
- Returns
141
- -------
142
- List[str]
143
- A list of labels (``"support"``, ``"contradict"``, or ``"neutral"``)
144
- corresponding to each input sentence. The ordering of the
145
- labels matches the ordering of the input sentences.
146
- """
147
- sentences_list = list(sentences)
148
- if not sentences_list:
149
  return []
150
 
151
- if _nli_model is None and not _use_transformers:
 
152
  _load_nli_model()
153
 
154
  if _use_transformers and _nli_model is not None and _nli_tokenizer is not None:
155
  try:
156
- return _classify_with_nli(claim, sentences_list, batch_size=batch_size)
157
  except Exception as exc:
158
  logger.warning(
159
- "NLI classification failed. Falling back to heuristic classifier: %s",
160
  exc,
161
  )
162
- # Mark transformers as unusable for subsequent calls
163
- global _use_transformers
164
  _use_transformers = False
165
  _nli_model = None
166
  _nli_tokenizer = None
167
 
168
- # Heuristic fallback
169
- return _heuristic_classify(claim, sentences_list)
 
3
  -------------
4
 
5
  This module defines utilities for classifying the relationship between a
6
+ claim and candidate sentences. It tries to use a transformers NLI
7
+ cross-encoder; if that fails, it falls back to a lightweight heuristic.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
+ Labels:
10
+ - "support" (entailment)
11
+ - "contradict" (contradiction)
12
+ - "neutral"
13
  """
14
 
15
  from __future__ import annotations
 
21
 
22
  logger = logging.getLogger(__name__)
23
 
24
+ _nli_model = None # type: ignore
25
+ _nli_tokenizer = None # type: ignore
26
+ _use_transformers = False # whether NLI model is successfully loaded
27
 
28
 
29
  def _load_nli_model(model_name: str = "cross-encoder/nli-roberta-base"):
30
+ """Lazy-load the NLI model and tokenizer; set fallback flag on failure."""
 
 
 
31
  global _nli_model, _nli_tokenizer, _use_transformers
32
+ if _nli_model is not None and _nli_tokenizer is not None and _use_transformers:
33
  return
34
  try:
35
  from transformers import AutoTokenizer, AutoModelForSequenceClassification # type: ignore
 
40
  _use_transformers = True
41
  except Exception as exc:
42
  logger.warning(
43
+ "Failed to load NLI model '%s'. Falling back to heuristic: %s",
44
  model_name,
45
  exc,
46
  )
 
56
 
57
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
58
  _nli_model.to(device)
 
59
 
60
+ # Order for nli-roberta-base: [contradiction, entailment, neutral]
 
61
  id2label = {0: "contradict", 1: "support", 2: "neutral"}
62
 
63
+ labels_out: List[str] = []
64
  for start in range(0, len(sentences), batch_size):
65
+ batch = sentences[start : start + batch_size]
66
+ enc = _nli_tokenizer(
67
+ [claim] * len(batch),
68
+ batch,
69
  return_tensors="pt",
70
  truncation=True,
71
  padding=True,
72
  ).to(device)
73
  with torch.no_grad():
74
+ logits = _nli_model(**enc).logits.cpu().numpy()
 
75
  preds = logits.argmax(axis=1)
76
  labels_out.extend([id2label.get(int(p), "neutral") for p in preds])
77
  return labels_out
78
 
79
 
80
  def _heuristic_classify(claim: str, sentences: List[str]) -> List[str]:
81
+ """Very simple heuristic fallback (lexical overlap + negation)."""
 
 
 
 
 
 
 
82
  import re
83
 
84
  claim_tokens = set(re.findall(r"\b\w+\b", claim.lower()))
85
+ neg = {"not", "no", "never", "none", "cannot", "n't"}
86
+ out: List[str] = []
87
+ for s in sentences:
88
+ s_tokens = set(re.findall(r"\b\w+\b", s.lower()))
89
+ overlap = bool(claim_tokens & s_tokens)
90
+ has_neg = any(tok in s_tokens for tok in neg)
91
  if overlap and not has_neg:
92
+ out.append("support")
93
  elif overlap and has_neg:
94
+ out.append("contradict")
95
  else:
96
+ out.append("neutral")
97
+ return out
98
 
99
 
100
def classify(claim: str, sentences: Iterable[str], batch_size: int = 16) -> List[str]:
    """Label every candidate sentence relative to ``claim``.

    Parameters
    ----------
    claim:
        The claim (hypothesis) to compare against.
    sentences:
        Iterable of candidate sentences; consumed once.
    batch_size:
        Batch size forwarded to the transformer-based classifier.

    Returns
    -------
    List[str]
        ``"support"``, ``"contradict"``, or ``"neutral"`` per sentence,
        in input order.
    """
    # These module globals are reassigned on NLI failure below, so they
    # must be declared before any use in this function.
    global _nli_model, _nli_tokenizer, _use_transformers

    candidates = list(sentences)
    if not candidates:
        return []

    # Lazily attempt to bring the NLI model up if it is not loaded yet.
    if _nli_model is None or _nli_tokenizer is None:
        _load_nli_model()

    if _use_transformers and _nli_model is not None and _nli_tokenizer is not None:
        try:
            return _classify_with_nli(claim, candidates, batch_size=batch_size)
        except Exception as exc:
            logger.warning(
                "NLI classification failed; switching to heuristic. Error: %s",
                exc,
            )
            # Drop the broken model/tokenizer. NOTE(review): because the
            # model is reset to None, _load_nli_model's guard lets a later
            # call retry loading rather than permanently staying heuristic.
            _use_transformers = False
            _nli_model = None
            _nli_tokenizer = None

    # Heuristic path: either loading failed or inference raised.
    return _heuristic_classify(claim, candidates)