fix: pass tokenizer explicitly to AbuseDataset and safeguard evaluation step
train_abuse_model.py  +8 -8

@@ -54,7 +54,7 @@ logger.info("PyTorch version:", torch.__version__)
 # Custom Dataset class
 
 class AbuseDataset(Dataset):
-    def __init__(self, texts, labels):
+    def __init__(self, texts, labels, tokenizer):
         self.encodings = tokenizer(texts, truncation=True, padding=True, max_length=512)
         self.labels = labels
 
@@ -223,10 +223,9 @@ def run_training():
         param.requires_grad = False
 
 
-    train_dataset = AbuseDataset(train_texts, train_labels)
-    val_dataset = AbuseDataset(val_texts, val_labels)
-    test_dataset = AbuseDataset(test_texts, test_labels)
-
+    train_dataset = AbuseDataset(train_texts, train_labels, tokenizer)
+    val_dataset = AbuseDataset(val_texts, val_labels, tokenizer)
+    test_dataset = AbuseDataset(test_texts, test_labels, tokenizer)
 
     # TrainingArguments for HuggingFace Trainer (logging, saving)
     training_args = TrainingArguments(
@@ -270,9 +269,10 @@ def run_training():
 
     # Evaluation
     try:
-        label_map = {0.0: "no", 0.5: "plausibly", 1.0: "yes"}
-        evaluate_model_with_thresholds(trainer, test_dataset)
-        logger.info("Evaluation completed")
+        if 'trainer' in locals():
+            label_map = {0.0: "no", 0.5: "plausibly", 1.0: "yes"}
+            evaluate_model_with_thresholds(trainer, test_dataset)
+            logger.info("Evaluation completed")
     except Exception as e:
         logger.exception(f"Evaluation failed: {e}")
         log_buffer.seek(0)
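
The diff touches only __init__, so for context, here is a minimal sketch of how the complete class plausibly reads after this change. The __len__ and __getitem__ bodies follow the usual HuggingFace Trainer dataset pattern and are assumptions, not code shown in this commit:

import torch
from torch.utils.data import Dataset

class AbuseDataset(Dataset):
    def __init__(self, texts, labels, tokenizer):
        # Tokenize up front; the tokenizer is now an explicit dependency
        # instead of a module-level global.
        self.encodings = tokenizer(texts, truncation=True, padding=True, max_length=512)
        self.labels = labels

    def __len__(self):
        # Assumed: one example per input text.
        return len(self.labels)

    def __getitem__(self, idx):
        # Assumed: Trainer-style items, one tensor per encoding field
        # plus a "labels" tensor.
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item["labels"] = torch.tensor(self.labels[idx])
        return item

Constructed the same way the commit now does, with a placeholder checkpoint for illustration:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # placeholder checkpoint
train_dataset = AbuseDataset(["example text"], [0.0], tokenizer)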
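On the safeguard itself: the 'trainer' in locals() check prevents a NameError when an earlier step fails before trainer is ever assigned. An equivalent, slightly more explicit pattern, sketched here as an alternative rather than as what the commit does, binds the name up front and tests for None:

trainer = None  # bound before training so the guard below cannot NameError
try:
    # ... earlier training step assigns trainer = Trainer(...) ...
    pass
except Exception:
    logger.exception("Training failed")

# Evaluation
try:
    if trainer is not None:  # plays the role of 'trainer' in locals()
        label_map = {0.0: "no", 0.5: "plausibly", 1.0: "yes"}
        evaluate_model_with_thresholds(trainer, test_dataset)
        logger.info("Evaluation completed")
except Exception as e:
    logger.exception(f"Evaluation failed: {e}")

Here evaluate_model_with_thresholds, test_dataset, and logger are the names already used in the script; the None-guard makes the dependency on a successful training step visible at the point where trainer is introduced.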