Commit ead9c37
Parent: 9a1ffc0

Update model/train.py
Cross Validation Implementation

model/train.py (CHANGED): +355 -116
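As a rough usage sketch (not part of the commit): after this change the trainer can presumably be driven with an explicit fold count via the new CLI flag, e.g. python model/train.py --data_path data/train.csv --cv_folds 5, or from Python. The import path and CSV location below are assumptions.

    # Hypothetical driver based on the entry points visible in this diff.
    from model.train import RobustModelTrainer

    trainer = RobustModelTrainer()
    trainer.cv_folds = 5                    # mirrors the new --cv_folds flag
    trainer.cv_manager.cv_folds = 5
    ok, message = trainer.train_model(data_path="data/train.csv")
    print(ok, message)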
@@ -1,3 +1,6 @@
1 | import seaborn as sns
2 | import matplotlib.pyplot as plt
3 | from sklearn.feature_selection import SelectKBest, chi2
@@ -10,7 +13,7 @@ from sklearn.metrics import (
10 | )
11 | from sklearn.model_selection import (
12 |     train_test_split, cross_val_score, GridSearchCV,
13 | -   StratifiedKFold, validation_curve
14 | )
15 | from sklearn.ensemble import RandomForestClassifier
16 | from sklearn.linear_model import LogisticRegression
@@ -26,7 +29,7 @@ import sys
26 | import os
27 | import time
28 | from datetime import datetime, timedelta
29 | - from typing import Dict, Tuple, Optional, Any
30 | import warnings
31 | import re
32 | warnings.filterwarnings('ignore')
@@ -143,7 +146,7 @@ class ProgressTracker:
143 |         print(f"\n{self.description} completed in {timedelta(seconds=int(total_time))}")
144 |
145 |
146 | - def estimate_training_time(dataset_size: int, enable_tuning: bool = True, cv_folds: int =
147 |     """Estimate training time based on dataset characteristics"""
148 |
149 |     # Base time estimates (in seconds) based on empirical testing
@@ -173,12 +176,15 @@ def estimate_training_time(dataset_size: int, enable_tuning: bool = True, cv_fol
173 |     estimates['vectorization'] = base_times['vectorization']
174 |     estimates['feature_selection'] = base_times['feature_selection']
175 |
176 | -   # Model training
177 |     for model_name, multiplier in tuning_multipliers.items():
178 |         model_time = base_times['simple_training'] * multiplier * cv_multiplier
179 |         estimates[f'{model_name}_training'] = model_time
180 |         estimates[f'{model_name}_evaluation'] = base_times['evaluation']
181 |
182 |     # Model saving
183 |     estimates['model_saving'] = 1.0
184 |
@@ -198,14 +204,189 @@ def estimate_training_time(dataset_size: int, enable_tuning: bool = True, cv_fol
198 |     }
199 |
200 |
201 | class RobustModelTrainer:
202 | -   """Production-ready model trainer with comprehensive
203 |
204 |     def __init__(self):
205 |         self.setup_paths()
206 |         self.setup_training_config()
207 |         self.setup_models()
208 |         self.progress_tracker = None
209 |
210 |     def setup_paths(self):
211 |         """Setup all necessary paths with proper permissions"""
@@ -232,11 +413,11 @@ class RobustModelTrainer:
232 |         self.evaluation_path = self.results_dir / "evaluation_results.json"
233 |
234 |     def setup_training_config(self):
235 | -       """Setup training configuration"""
236 |         self.test_size = 0.2
237 |         self.validation_size = 0.1
238 |         self.random_state = 42
239 | -       self.cv_folds =
240 |         self.max_features = 5000  # Reduced for speed
241 |         self.min_df = 1  # More lenient for small datasets
242 |         self.max_df = 0.95
@@ -312,13 +493,13 @@
312 |         if len(unique_labels) < 2:
313 |             return False, None, f"Need at least 2 classes, found: {unique_labels}"
314 |
315 | -       # Check minimum sample size
316 | -
317 | -
318 | -
319 | -
320 | -
321 | -       logger.
322 |
323 |         # Check class balance
324 |         label_counts = df['label'].value_counts()
@@ -378,7 +559,7 @@
378 |         return pipeline
379 |
380 |     def comprehensive_evaluation(self, model, X_test, y_test, X_train=None, y_train=None) -> Dict:
381 | -       """Comprehensive model evaluation with
382 |
383 |         if self.progress_tracker:
384 |             self.progress_tracker.update("Evaluating model")
@@ -400,40 +581,22 @@
400 |         cm = confusion_matrix(y_test, y_pred)
401 |         metrics['confusion_matrix'] = cm.tolist()
402 |
403 | -       #
404 | -       if X_train is not None and y_train is not None
405 | -
406 | -
407 | -
408 | -
409 | -
410 | -
411 | -
412 | -
413 | -
414 | -
415 | -
416 | -
417 | -
418 | -
419 | -           ),
420 | -           scoring='f1_weighted',
421 | -           n_jobs=1  # Single job for small datasets
422 | -       )
423 | -       metrics['cv_scores'] = {
424 | -           'mean': float(cv_scores.mean()),
425 | -           'std': float(cv_scores.std()),
426 | -           'scores': cv_scores.tolist(),
427 | -           'folds_used': cv_folds
428 | -       }
429 | -       else:
430 | -           metrics['cv_scores'] = {'note': 'Dataset too small for reliable CV'}
431 | -       except Exception as e:
432 | -           logger.warning(f"Cross-validation failed: {e}")
433 | -           metrics['cv_scores'] = {'note': f'CV failed: {str(e)}'}
434 | -       else:
435 | -           metrics['cv_scores'] = {'note': 'Skipped for very small dataset'}
436 | -
437 |         # Training accuracy for overfitting detection
438 |         try:
439 |             if X_train is not None and y_train is not None:
@@ -447,11 +610,11 @@
447 |
448 |         return metrics
449 |
450 | -   def
451 | -       """Perform hyperparameter tuning with cross-validation"""
452 |
453 |         if self.progress_tracker:
454 | -           self.progress_tracker.update(f"Tuning {model_name}")
455 |
456 |         try:
457 |             # Set the model in the pipeline
@@ -461,63 +624,68 @@
461 |             if len(X_train) < 20:
462 |                 logger.info(f"Skipping hyperparameter tuning for {model_name} due to small dataset")
463 |                 pipeline.fit(X_train, y_train)
464 |                 return pipeline, {
465 |                     'best_params': 'default_parameters',
466 | -                   'best_score': 'not_calculated',
467 |                     'best_estimator': pipeline,
468 |                     'note': 'Hyperparameter tuning skipped for small dataset'
469 |                 }
470 |
471 |             # Get parameter grid
472 |             param_grid = self.models[model_name]['param_grid']
473 |
474 | -           #
475 | -
476 | -           min_samples_per_fold = 3
477 | -           max_folds = n_samples // min_samples_per_fold
478 | -           cv_folds = max(2, min(self.cv_folds, max_folds))
479 |
480 | -
481 | -               # Fallback to simple training
482 | -               logger.info(f"Dataset too small for CV, using simple training for {model_name}")
483 | -               pipeline.fit(X_train, y_train)
484 | -               return pipeline, {
485 | -                   'best_params': 'default_parameters',
486 | -                   'best_score': 'not_calculated',
487 | -                   'best_estimator': pipeline,
488 | -                   'note': 'Simple training used due to very small dataset'
489 | -               }
490 | -
491 | -           # Create GridSearchCV
492 |             grid_search = GridSearchCV(
493 |                 pipeline,
494 |                 param_grid,
495 | -               cv=
496 | -               shuffle=True, random_state=self.random_state),
497 |                 scoring='f1_weighted',
498 | -               n_jobs=1,  # Single job for
499 | -               verbose=0
500 |             )
501 |
502 |             # Fit grid search
503 |             grid_search.fit(X_train, y_train)
504 |
505 |             # Extract results
506 |             tuning_results = {
507 |                 'best_params': grid_search.best_params_,
508 |                 'best_score': float(grid_search.best_score_),
509 |                 'best_estimator': grid_search.best_estimator_,
510 | -               'cv_folds_used':
511 | -               '
512 |                 'mean_test_scores': grid_search.cv_results_['mean_test_score'].tolist(),
513 |                 'std_test_scores': grid_search.cv_results_['std_test_score'].tolist(),
514 |                 'params': grid_search.cv_results_['params']
515 |             }
516 |             }
517 |
518 |             logger.info(f"Hyperparameter tuning completed for {model_name}")
519 | -           logger.info(f"Best score: {grid_search.best_score_:.4f}")
520 |             logger.info(f"Best params: {grid_search.best_params_}")
521 |
522 |             return grid_search.best_estimator_, tuning_results
523 |
@@ -527,29 +695,37 @@
527 |         try:
528 |             pipeline.set_params(model=self.models[model_name]['model'])
529 |             pipeline.fit(X_train, y_train)
530 | -
531 |         except Exception as e2:
532 |             logger.error(f"Fallback training also failed for {model_name}: {str(e2)}")
533 |             raise Exception(f"Both hyperparameter tuning and fallback training failed: {str(e)} | {str(e2)}")
534 |
535 |     def train_and_evaluate_models(self, X_train, X_test, y_train, y_test) -> Dict:
536 | -       """Train and evaluate multiple models"""
537 |
538 |         results = {}
539 |
540 |         for model_name in self.models.keys():
541 | -           logger.info(f"Training {model_name}...")
542 |
543 |             try:
544 |                 # Create pipeline
545 |                 pipeline = self.create_preprocessing_pipeline()
546 |
547 | -               # Hyperparameter tuning
548 | -               best_model, tuning_results = self.
549 |                     pipeline, X_train, y_train, model_name
550 |                 )
551 |
552 | -               # Comprehensive evaluation
553 |                 evaluation_metrics = self.comprehensive_evaluation(
554 |                     best_model, X_test, y_test, X_train, y_train
555 |                 )
@@ -562,8 +738,15 @@
562 |                     'training_time': datetime.now().isoformat()
563 |                 }
564 |
565 | -
566 | -
567 |
568 |             except Exception as e:
569 |                 logger.error(f"Training failed for {model_name}: {str(e)}")
@@ -572,7 +755,7 @@
572 |         return results
573 |
574 |     def select_best_model(self, results: Dict) -> Tuple[str, Any, Dict]:
575 | -       """Select the best performing model"""
576 |
577 |         if self.progress_tracker:
578 |             self.progress_tracker.update("Selecting best model")
@@ -586,8 +769,14 @@
586 |             if 'error' in result:
587 |                 continue
588 |
589 | -           #
590 | -
591 |
592 |             if f1_score > best_score:
593 |                 best_score = f1_score
@@ -598,12 +787,11 @@
598 |         if best_model_name is None:
599 |             raise ValueError("No models trained successfully")
600 |
601 | -       logger.info(
602 | -           f"Best model: {best_model_name} with F1 score: {best_score:.4f}")
603 |         return best_model_name, best_model, best_metrics
604 |
605 | -   def save_model_artifacts(self, model, model_name: str, metrics: Dict) -> bool:
606 | -       """Save model artifacts and metadata with
607 |         try:
608 |             if self.progress_tracker:
609 |                 self.progress_tracker.update("Saving model")
@@ -637,7 +825,10 @@
637 |             # Generate data hash
638 |             data_hash = hashlib.md5(str(datetime.now()).encode()).hexdigest()
639 |
640 | -           #
641 |             metadata = {
642 |                 'model_version': f"v1.0_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
643 |                 'model_type': model_name,
@@ -648,8 +839,6 @@
648 |                 'test_recall': metrics['recall'],
649 |                 'test_roc_auc': metrics['roc_auc'],
650 |                 'overfitting_score': metrics.get('overfitting_score', 'Unknown'),
651 | -               'cv_score_mean': metrics.get('cv_scores', {}).get('mean', 'Unknown'),
652 | -               'cv_score_std': metrics.get('cv_scores', {}).get('std', 'Unknown'),
653 |                 'timestamp': datetime.now().isoformat(),
654 |                 'training_config': {
655 |                     'test_size': self.test_size,
@@ -659,16 +848,51 @@
659 |                     'feature_selection_k': self.feature_selection_k
660 |                 }
661 |             }
662 |
663 |             # Save metadata with error handling
664 |             try:
665 |                 with open(self.metadata_path, 'w') as f:
666 |                     json.dump(metadata, f, indent=2)
667 | -               logger.info(f"✅ Saved metadata to {self.metadata_path}")
668 |             except Exception as e:
669 |                 logger.warning(f"Could not save metadata: {e}")
670 |
671 | -           logger.info(f"✅ Model artifacts saved successfully")
672 |             return True
673 |
674 |         except Exception as e:
@@ -683,9 +907,9 @@
683 |             return False
684 |
685 |     def train_model(self, data_path: str = None) -> Tuple[bool, str]:
686 | -       """Main training function with comprehensive pipeline"""
687 |         try:
688 | -           logger.info("Starting model training
689 |
690 |             # Override data path if provided
691 |             if data_path:
@@ -703,16 +927,17 @@
703 |                 cv_folds=self.cv_folds
704 |             )
705 |
706 | -           print(f"\n📊 Training Configuration:")
707 |             print(f"Dataset size: {len(df)} samples")
708 |             print(f"Estimated time: {time_estimate['total_formatted']}")
709 |             print(f"Models to train: {len(self.models)}")
710 | -           print(f"
711 |             print()
712 |
713 | -           # Setup progress tracker
714 | -           total_steps = 4 + (len(self.models) *
715 | -           self.progress_tracker = ProgressTracker(total_steps, "Training Progress")
716 |
717 |             # Prepare data
718 |             X = df['text'].values
@@ -743,28 +968,35 @@
743 |
744 |             # Additional validation for very small datasets
745 |             if len(X_train) < 3:
746 | -               logger.warning(f"Very small training set: {len(X_train)} samples.
747 |             if len(X_test) < 1:
748 |                 return False, "Cannot create test set. Dataset too small."
749 |
750 | -           # Train and evaluate models
751 | -           results = self.train_and_evaluate_models(
752 | -               X_train, X_test, y_train, y_test)
753 |
754 |             # Select best model
755 |             best_model_name, best_model, best_metrics = self.select_best_model(results)
756 |
757 | -           # Save model artifacts
758 | -           if not self.save_model_artifacts(best_model, best_model_name, best_metrics):
759 |                 return False, "Failed to save model artifacts"
760 |
761 |             # Finish progress tracking
762 |             self.progress_tracker.finish()
763 |
764 |             success_message = (
765 | -               f"
766 |                 f"Best model: {best_model_name} "
767 | -               f"(F1: {best_metrics['f1']:.4f}, Accuracy: {best_metrics['accuracy']:.4f})"
768 |             )
769 |
770 |             logger.info(success_message)
@@ -773,23 +1005,29 @@
773 |         except Exception as e:
774 |             if self.progress_tracker:
775 |                 print()  # New line after progress bar
776 | -           error_message = f"
777 |             logger.error(error_message)
778 |             return False, error_message
779 |
780 |
781 | def main():
782 | -   """Main execution function"""
783 |     import argparse
784 |
785 |     # Parse command line arguments
786 | -   parser = argparse.ArgumentParser(description='Train fake news detection model')
787 |     parser.add_argument('--data_path', type=str, help='Path to training data CSV file')
788 |     parser.add_argument('--config_path', type=str, help='Path to training configuration JSON file')
789 |     args = parser.parse_args()
790 |
791 |     trainer = RobustModelTrainer()
792 |
793 |     # Load custom configuration if provided
794 |     if args.config_path and Path(args.config_path).exists():
795 |         try:
@@ -799,6 +1037,7 @@
799 |             # Apply configuration
800 |             trainer.test_size = config.get('test_size', trainer.test_size)
801 |             trainer.cv_folds = config.get('cv_folds', trainer.cv_folds)
802 |             trainer.max_features = config.get('max_features', trainer.max_features)
803 |             trainer.ngram_range = tuple(config.get('ngram_range', trainer.ngram_range))
804 |
@@ -811,7 +1050,7 @@
811 |             # Update feature selection based on max_features
812 |             trainer.feature_selection_k = min(trainer.feature_selection_k, trainer.max_features)
813 |
814 | -           logger.info(f"Applied custom configuration: {config}")
815 |
816 |         except Exception as e:
817 |             logger.warning(f"Failed to load configuration: {e}, using defaults")
1 | + # File: model/train.py (MODIFIED)
2 | + # Enhanced version with comprehensive cross-validation implementation
3 | +
4 | import seaborn as sns
5 | import matplotlib.pyplot as plt
6 | from sklearn.feature_selection import SelectKBest, chi2

13 | )
14 | from sklearn.model_selection import (
15 |     train_test_split, cross_val_score, GridSearchCV,
16 | +   StratifiedKFold, validation_curve, cross_validate
17 | )
18 | from sklearn.ensemble import RandomForestClassifier
19 | from sklearn.linear_model import LogisticRegression

29 | import os
30 | import time
31 | from datetime import datetime, timedelta
32 | + from typing import Dict, Tuple, Optional, Any, List
33 | import warnings
34 | import re
35 | warnings.filterwarnings('ignore')

146 |         print(f"\n{self.description} completed in {timedelta(seconds=int(total_time))}")
147 |
148 |
149 | + def estimate_training_time(dataset_size: int, enable_tuning: bool = True, cv_folds: int = 5) -> Dict:
150 |     """Estimate training time based on dataset characteristics"""
151 |
152 |     # Base time estimates (in seconds) based on empirical testing

176 |     estimates['vectorization'] = base_times['vectorization']
177 |     estimates['feature_selection'] = base_times['feature_selection']
178 |
179 | +   # Model training (now includes CV)
180 |     for model_name, multiplier in tuning_multipliers.items():
181 |         model_time = base_times['simple_training'] * multiplier * cv_multiplier
182 |         estimates[f'{model_name}_training'] = model_time
183 |         estimates[f'{model_name}_evaluation'] = base_times['evaluation']
184 |
185 | +   # Cross-validation overhead
186 | +   estimates['cross_validation'] = base_times['simple_training'] * cv_folds * 0.5
187 | +
188 |     # Model saving
189 |     estimates['model_saving'] = 1.0
190 |
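To make the new overhead term concrete, a quick arithmetic sketch; the 2.0 s base time is an assumed figure, while the 0.5 factor and the cv_folds multiplier come from the lines above.

    # Illustrative only: overhead added by 5-fold CV if one plain fit takes ~2 s.
    base_simple_training = 2.0              # assumed seconds per simple fit
    cv_folds = 5
    cv_overhead = base_simple_training * cv_folds * 0.5
    print(f"{cv_overhead:.1f} s")           # 5.0 s added to the total estimate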
204 | }
205 |
206 |
207 | + class CrossValidationManager:
208 | +     """Advanced cross-validation management with comprehensive metrics"""
209 | +
210 | +     def __init__(self, cv_folds: int = 5, random_state: int = 42):
211 | +         self.cv_folds = cv_folds
212 | +         self.random_state = random_state
213 | +         self.cv_results = {}
214 | +
215 | +     def create_cv_strategy(self, X, y) -> StratifiedKFold:
216 | +         """Create appropriate CV strategy based on data characteristics"""
217 | +         # Calculate appropriate CV folds for small datasets
218 | +         n_samples = len(X)
219 | +         min_samples_per_fold = 3  # Minimum samples per fold
220 | +         max_folds = n_samples // min_samples_per_fold
221 | +
222 | +         # Adjust folds based on data size and class distribution
223 | +         unique_classes = np.unique(y)
224 | +         min_class_count = min([np.sum(y == cls) for cls in unique_classes])
225 | +
226 | +         # Ensure each fold has at least one sample from each class
227 | +         max_folds_by_class = min_class_count
228 | +
229 | +         actual_folds = max(2, min(self.cv_folds, max_folds, max_folds_by_class))
230 | +
231 | +         logger.info(f"Using {actual_folds} CV folds (requested: {self.cv_folds})")
232 | +
233 | +         return StratifiedKFold(
234 | +             n_splits=actual_folds,
235 | +             shuffle=True,
236 | +             random_state=self.random_state
237 | +         )
238 | +
239 | +     def perform_cross_validation(self, pipeline, X, y, cv_strategy=None) -> Dict:
240 | +         """Perform comprehensive cross-validation with multiple metrics"""
241 | +
242 | +         if cv_strategy is None:
243 | +             cv_strategy = self.create_cv_strategy(X, y)
244 | +
245 | +         logger.info(f"Starting cross-validation with {cv_strategy.n_splits} folds...")
246 | +
247 | +         # Define scoring metrics
248 | +         scoring_metrics = {
249 | +             'accuracy': 'accuracy',
250 | +             'precision': 'precision_weighted',
251 | +             'recall': 'recall_weighted',
252 | +             'f1': 'f1_weighted',
253 | +             'roc_auc': 'roc_auc'
254 | +         }
255 | +
256 | +         try:
257 | +             # Perform cross-validation
258 | +             cv_scores = cross_validate(
259 | +                 pipeline, X, y,
260 | +                 cv=cv_strategy,
261 | +                 scoring=scoring_metrics,
262 | +                 return_train_score=True,
263 | +                 n_jobs=1,  # Use single job for stability
264 | +                 verbose=0
265 | +             )
266 | +
267 | +             # Process results
268 | +             cv_results = {
269 | +                 'n_splits': cv_strategy.n_splits,
270 | +                 'test_scores': {},
271 | +                 'train_scores': {},
272 | +                 'fold_results': []
273 | +             }
274 | +
275 | +             # Calculate statistics for each metric
276 | +             for metric_name in scoring_metrics.keys():
277 | +                 test_key = f'test_{metric_name}'
278 | +                 train_key = f'train_{metric_name}'
279 | +
280 | +                 if test_key in cv_scores:
281 | +                     test_scores = cv_scores[test_key]
282 | +                     cv_results['test_scores'][metric_name] = {
283 | +                         'mean': float(np.mean(test_scores)),
284 | +                         'std': float(np.std(test_scores)),
285 | +                         'min': float(np.min(test_scores)),
286 | +                         'max': float(np.max(test_scores)),
287 | +                         'scores': test_scores.tolist()
288 | +                     }
289 | +
290 | +                 if train_key in cv_scores:
291 | +                     train_scores = cv_scores[train_key]
292 | +                     cv_results['train_scores'][metric_name] = {
293 | +                         'mean': float(np.mean(train_scores)),
294 | +                         'std': float(np.std(train_scores)),
295 | +                         'min': float(np.min(train_scores)),
296 | +                         'max': float(np.max(train_scores)),
297 | +                         'scores': train_scores.tolist()
298 | +                     }
299 | +
300 | +             # Store individual fold results
301 | +             for fold_idx in range(cv_strategy.n_splits):
302 | +                 fold_result = {
303 | +                     'fold': fold_idx + 1,
304 | +                     'test_scores': {},
305 | +                     'train_scores': {}
306 | +                 }
307 | +
308 | +                 for metric_name in scoring_metrics.keys():
309 | +                     test_key = f'test_{metric_name}'
310 | +                     train_key = f'train_{metric_name}'
311 | +
312 | +                     if test_key in cv_scores:
313 | +                         fold_result['test_scores'][metric_name] = float(cv_scores[test_key][fold_idx])
314 | +                     if train_key in cv_scores:
315 | +                         fold_result['train_scores'][metric_name] = float(cv_scores[train_key][fold_idx])
316 | +
317 | +                 cv_results['fold_results'].append(fold_result)
318 | +
319 | +             # Calculate overfitting indicators
320 | +             if 'accuracy' in cv_results['test_scores'] and 'accuracy' in cv_results['train_scores']:
321 | +                 train_mean = cv_results['train_scores']['accuracy']['mean']
322 | +                 test_mean = cv_results['test_scores']['accuracy']['mean']
323 | +                 cv_results['overfitting_score'] = float(train_mean - test_mean)
324 | +
325 | +             # Calculate stability metrics
326 | +             if 'accuracy' in cv_results['test_scores']:
327 | +                 test_std = cv_results['test_scores']['accuracy']['std']
328 | +                 test_mean = cv_results['test_scores']['accuracy']['mean']
329 | +                 cv_results['stability_score'] = float(1 - (test_std / test_mean)) if test_mean > 0 else 0
330 | +
331 | +             logger.info("Cross-validation completed successfully")
332 | +             logger.info(f"Mean test accuracy: {cv_results['test_scores']['accuracy']['mean']:.4f}")
333 | +             logger.info(f"Mean test F1: {cv_results['test_scores']['f1']['mean']:.4f}")
334 | +
335 | +             return cv_results
336 | +
337 | +         except Exception as e:
338 | +             logger.error(f"Cross-validation failed: {e}")
339 | +             return {
340 | +                 'error': str(e),
341 | +                 'n_splits': cv_strategy.n_splits if cv_strategy else self.cv_folds,
342 | +                 'fallback': True
343 | +             }
344 | +
345 | +     def compare_cv_results(self, results1: Dict, results2: Dict, metric: str = 'f1') -> Dict:
346 | +         """Compare cross-validation results between two models"""
347 | +
348 | +         try:
349 | +             if 'error' in results1 or 'error' in results2:
350 | +                 return {'error': 'Cannot compare results with errors'}
351 | +
352 | +             scores1 = results1['test_scores'][metric]['scores']
353 | +             scores2 = results2['test_scores'][metric]['scores']
354 | +
355 | +             # Paired t-test
356 | +             from scipy import stats
357 | +             t_stat, p_value = stats.ttest_rel(scores1, scores2)
358 | +
359 | +             comparison = {
360 | +                 'metric': metric,
361 | +                 'model1_mean': results1['test_scores'][metric]['mean'],
362 | +                 'model2_mean': results2['test_scores'][metric]['mean'],
363 | +                 'model1_std': results1['test_scores'][metric]['std'],
364 | +                 'model2_std': results2['test_scores'][metric]['std'],
365 | +                 'difference': results2['test_scores'][metric]['mean'] - results1['test_scores'][metric]['mean'],
366 | +                 'paired_ttest': {
367 | +                     't_statistic': float(t_stat),
368 | +                     'p_value': float(p_value),
369 | +                     'significant': p_value < 0.05
370 | +                 },
371 | +                 'effect_size': float(abs(t_stat) / np.sqrt(len(scores1))) if len(scores1) > 0 else 0
372 | +             }
373 | +
374 | +             return comparison
375 | +
376 | +         except Exception as e:
377 | +             logger.error(f"CV comparison failed: {e}")
378 | +             return {'error': str(e)}
379 | +
380 | +
381 | class RobustModelTrainer:
382 | +     """Production-ready model trainer with comprehensive cross-validation"""
383 |
384 |     def __init__(self):
385 |         self.setup_paths()
386 |         self.setup_training_config()
387 |         self.setup_models()
388 |         self.progress_tracker = None
389 | +       self.cv_manager = CrossValidationManager()
390 |
391 |     def setup_paths(self):
392 |         """Setup all necessary paths with proper permissions"""
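The new class can also be exercised on its own. A minimal sketch, assuming it runs inside model/train.py where logger and np are already set up; the toy pipeline and data are placeholders, not part of the commit.

    import numpy as np
    from sklearn.pipeline import Pipeline
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.linear_model import LogisticRegression

    texts = np.array(["plainly factual report"] * 15 + ["obviously fake claim"] * 15)
    labels = np.array([0] * 15 + [1] * 15)
    pipe = Pipeline([("tfidf", TfidfVectorizer()), ("model", LogisticRegression())])

    cv_manager = CrossValidationManager(cv_folds=5, random_state=42)
    results = cv_manager.perform_cross_validation(pipe, texts, labels)
    print(results["n_splits"])                     # 5 for this toy dataset
    print(results["test_scores"]["f1"]["mean"])    # mean weighted F1 across folds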
413 |         self.evaluation_path = self.results_dir / "evaluation_results.json"
414 |
415 |     def setup_training_config(self):
416 | +       """Setup training configuration with CV parameters"""
417 |         self.test_size = 0.2
418 |         self.validation_size = 0.1
419 |         self.random_state = 42
420 | +       self.cv_folds = 5  # Primary CV folds
421 |         self.max_features = 5000  # Reduced for speed
422 |         self.min_df = 1  # More lenient for small datasets
423 |         self.max_df = 0.95

493 |         if len(unique_labels) < 2:
494 |             return False, None, f"Need at least 2 classes, found: {unique_labels}"
495 |
496 | +       # Check minimum sample size for CV
497 | +       min_samples_for_cv = self.cv_folds * 2  # At least 2 samples per fold
498 | +       if len(df) < min_samples_for_cv:
499 | +           logger.warning(f"Dataset size ({len(df)}) is small for {self.cv_folds}-fold CV")
500 | +           # Adjust CV folds for small datasets
501 | +           self.cv_manager.cv_folds = max(2, len(df) // 3)
502 | +           logger.info(f"Adjusted CV folds to {self.cv_manager.cv_folds}")
503 |
504 |         # Check class balance
505 |         label_counts = df['label'].value_counts()

559 |         return pipeline
560 |
561 |     def comprehensive_evaluation(self, model, X_test, y_test, X_train=None, y_train=None) -> Dict:
562 | +       """Comprehensive model evaluation with cross-validation integration"""
563 |
564 |         if self.progress_tracker:
565 |             self.progress_tracker.update("Evaluating model")

581 |         cm = confusion_matrix(y_test, y_pred)
582 |         metrics['confusion_matrix'] = cm.tolist()
583 |
584 | +       # Cross-validation on full dataset
585 | +       if X_train is not None and y_train is not None:
586 | +           # Combine train and test for full dataset CV
587 | +           X_full = np.concatenate([X_train, X_test])
588 | +           y_full = np.concatenate([y_train, y_test])
589 | +
590 | +           logger.info("Performing cross-validation on full dataset...")
591 | +           cv_results = self.cv_manager.perform_cross_validation(model, X_full, y_full)
592 | +           metrics['cross_validation'] = cv_results
593 | +
594 | +           # Log CV results
595 | +           if 'test_scores' in cv_results and 'f1' in cv_results['test_scores']:
596 | +               cv_f1_mean = cv_results['test_scores']['f1']['mean']
597 | +               cv_f1_std = cv_results['test_scores']['f1']['std']
598 | +               logger.info(f"CV F1 Score: {cv_f1_mean:.4f} (±{cv_f1_std:.4f})")
599 | +
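One caveat on the block above: X_full/y_full concatenate the training and test splits, so these CV scores are computed over data that includes the held-out test set and are not independent of the test metrics reported next to them. A stricter variant (an alternative sketch, not what this commit does) would cross-validate on the training split only:

    # Hypothetical leakage-free variant: restrict CV to the training split.
    cv_results = self.cv_manager.perform_cross_validation(model, X_train, y_train)
    metrics['cross_validation'] = cv_results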
600 |         # Training accuracy for overfitting detection
601 |         try:
602 |             if X_train is not None and y_train is not None:

610 |
611 |         return metrics
612 |
613 | +     def hyperparameter_tuning_with_cv(self, pipeline, X_train, y_train, model_name: str) -> Tuple[Any, Dict]:
614 | +       """Perform hyperparameter tuning with nested cross-validation"""
615 |
616 |         if self.progress_tracker:
617 | +           self.progress_tracker.update(f"Tuning {model_name} with CV")
618 |
619 |         try:
620 |             # Set the model in the pipeline

624 |             if len(X_train) < 20:
625 |                 logger.info(f"Skipping hyperparameter tuning for {model_name} due to small dataset")
626 |                 pipeline.fit(X_train, y_train)
627 | +
628 | +               # Still perform CV evaluation
629 | +               cv_results = self.cv_manager.perform_cross_validation(pipeline, X_train, y_train)
630 | +
631 |                 return pipeline, {
632 |                     'best_params': 'default_parameters',
633 | +                   'best_score': cv_results.get('test_scores', {}).get('f1', {}).get('mean', 'not_calculated'),
634 |                     'best_estimator': pipeline,
635 | +                   'cross_validation': cv_results,
636 |                     'note': 'Hyperparameter tuning skipped for small dataset'
637 |                 }
638 |
639 |             # Get parameter grid
640 |             param_grid = self.models[model_name]['param_grid']
641 |
642 | +           # Create CV strategy
643 | +           cv_strategy = self.cv_manager.create_cv_strategy(X_train, y_train)
644 |
645 | +           # Create GridSearchCV with nested cross-validation
646 |             grid_search = GridSearchCV(
647 |                 pipeline,
648 |                 param_grid,
649 | +               cv=cv_strategy,
650 |                 scoring='f1_weighted',
651 | +               n_jobs=1,  # Single job for stability
652 | +               verbose=0,  # Reduce verbosity for speed
653 | +               return_train_score=True  # For overfitting analysis
654 |             )
655 |
656 |             # Fit grid search
657 | +           logger.info(f"Starting hyperparameter tuning for {model_name}...")
658 |             grid_search.fit(X_train, y_train)
659 |
660 | +           # Perform additional CV on best model
661 | +           logger.info(f"Performing final CV evaluation for {model_name}...")
662 | +           best_cv_results = self.cv_manager.perform_cross_validation(
663 | +               grid_search.best_estimator_, X_train, y_train, cv_strategy
664 | +           )
665 | +
666 |             # Extract results
667 |             tuning_results = {
668 |                 'best_params': grid_search.best_params_,
669 |                 'best_score': float(grid_search.best_score_),
670 |                 'best_estimator': grid_search.best_estimator_,
671 | +               'cv_folds_used': cv_strategy.n_splits,
672 | +               'cross_validation': best_cv_results,
673 | +               'grid_search_results': {
674 |                     'mean_test_scores': grid_search.cv_results_['mean_test_score'].tolist(),
675 |                     'std_test_scores': grid_search.cv_results_['std_test_score'].tolist(),
676 | +                   'mean_train_scores': grid_search.cv_results_['mean_train_score'].tolist() if 'mean_train_score' in grid_search.cv_results_ else [],
677 |                     'params': grid_search.cv_results_['params']
678 |                 }
679 |             }
680 |
681 |             logger.info(f"Hyperparameter tuning completed for {model_name}")
682 | +           logger.info(f"Best CV score: {grid_search.best_score_:.4f}")
683 |             logger.info(f"Best params: {grid_search.best_params_}")
684 | +
685 | +           if 'test_scores' in best_cv_results and 'f1' in best_cv_results['test_scores']:
686 | +               final_f1 = best_cv_results['test_scores']['f1']['mean']
687 | +               final_f1_std = best_cv_results['test_scores']['f1']['std']
688 | +               logger.info(f"Final CV F1: {final_f1:.4f} (±{final_f1_std:.4f})")
689 |
690 |             return grid_search.best_estimator_, tuning_results
691 |
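A sketch of consuming the tuple this method returns, with key names taken from the diff; the call site, pipeline, and the 'logistic_regression' model key are assumptions.

    # Hypothetical call site inside RobustModelTrainer.
    best_model, tuning = self.hyperparameter_tuning_with_cv(
        pipeline, X_train, y_train, 'logistic_regression'
    )
    print(tuning['best_params'])
    print(tuning['best_score'])          # inner GridSearchCV score (f1_weighted)
    print(tuning.get('cv_folds_used'))
    print(tuning['cross_validation']['test_scores']['f1']['mean'])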
695 |         try:
696 |             pipeline.set_params(model=self.models[model_name]['model'])
697 |             pipeline.fit(X_train, y_train)
698 | +
699 | +           # Perform basic CV
700 | +           cv_results = self.cv_manager.perform_cross_validation(pipeline, X_train, y_train)
701 | +
702 | +           return pipeline, {
703 | +               'error': str(e),
704 | +               'fallback': 'simple_training',
705 | +               'cross_validation': cv_results
706 | +           }
707 |         except Exception as e2:
708 |             logger.error(f"Fallback training also failed for {model_name}: {str(e2)}")
709 |             raise Exception(f"Both hyperparameter tuning and fallback training failed: {str(e)} | {str(e2)}")
710 |
711 |     def train_and_evaluate_models(self, X_train, X_test, y_train, y_test) -> Dict:
712 | +       """Train and evaluate multiple models with comprehensive CV"""
713 |
714 |         results = {}
715 |
716 |         for model_name in self.models.keys():
717 | +           logger.info(f"Training {model_name} with cross-validation...")
718 |
719 |             try:
720 |                 # Create pipeline
721 |                 pipeline = self.create_preprocessing_pipeline()
722 |
723 | +               # Hyperparameter tuning with CV
724 | +               best_model, tuning_results = self.hyperparameter_tuning_with_cv(
725 |                     pipeline, X_train, y_train, model_name
726 |                 )
727 |
728 | +               # Comprehensive evaluation (includes additional CV)
729 |                 evaluation_metrics = self.comprehensive_evaluation(
730 |                     best_model, X_test, y_test, X_train, y_train
731 |                 )

738 |                     'training_time': datetime.now().isoformat()
739 |                 }
740 |
741 | +               # Log results
742 | +               test_f1 = evaluation_metrics['f1']
743 | +               cv_results = evaluation_metrics.get('cross_validation', {})
744 | +               cv_f1_mean = cv_results.get('test_scores', {}).get('f1', {}).get('mean', 'N/A')
745 | +               cv_f1_std = cv_results.get('test_scores', {}).get('f1', {}).get('std', 'N/A')
746 | +               cv_f1_mean_str = f"{cv_f1_mean:.4f}" if cv_f1_mean != 'N/A' else 'N/A'
747 | +               cv_f1_std_str = f"{cv_f1_std:.4f}" if cv_f1_std != 'N/A' else 'N/A'
748 | +               logger.info(f"Model {model_name} - Test F1: {test_f1:.4f}, "
749 | +                           f"CV F1: {cv_f1_mean_str} (±{cv_f1_std_str})")
750 |
751 |             except Exception as e:
752 |                 logger.error(f"Training failed for {model_name}: {str(e)}")

755 |         return results
756 |
757 |     def select_best_model(self, results: Dict) -> Tuple[str, Any, Dict]:
758 | +       """Select the best performing model based on CV results"""
759 |
760 |         if self.progress_tracker:
761 |             self.progress_tracker.update("Selecting best model")

769 |             if 'error' in result:
770 |                 continue
771 |
772 | +           # Prioritize CV F1 score if available, fallback to test F1
773 | +           cv_results = result['evaluation_metrics'].get('cross_validation', {})
774 | +           if 'test_scores' in cv_results and 'f1' in cv_results['test_scores']:
775 | +               f1_score = cv_results['test_scores']['f1']['mean']
776 | +               score_type = "CV F1"
777 | +           else:
778 | +               f1_score = result['evaluation_metrics']['f1']
779 | +               score_type = "Test F1"
780 |
781 |             if f1_score > best_score:
782 |                 best_score = f1_score
787 |         if best_model_name is None:
788 |             raise ValueError("No models trained successfully")
789 |
790 | +       logger.info(f"Best model: {best_model_name} with {score_type} score: {best_score:.4f}")
791 |         return best_model_name, best_model, best_metrics
792 |
793 | +   def save_model_artifacts(self, model, model_name: str, metrics: Dict, results: Dict) -> bool:
794 | +       """Save model artifacts and enhanced metadata with CV results"""
795 |         try:
796 |             if self.progress_tracker:
797 |                 self.progress_tracker.update("Saving model")

825 |             # Generate data hash
826 |             data_hash = hashlib.md5(str(datetime.now()).encode()).hexdigest()
827 |
828 | +           # Extract CV results
829 | +           cv_results = metrics.get('cross_validation', {})
830 | +
831 | +           # Create enhanced metadata with CV information
832 |             metadata = {
833 |                 'model_version': f"v1.0_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
834 |                 'model_type': model_name,

839 |                 'test_recall': metrics['recall'],
840 |                 'test_roc_auc': metrics['roc_auc'],
841 |                 'overfitting_score': metrics.get('overfitting_score', 'Unknown'),
842 |                 'timestamp': datetime.now().isoformat(),
843 |                 'training_config': {
844 |                     'test_size': self.test_size,
848 |                     'feature_selection_k': self.feature_selection_k
849 |                 }
850 |             }
851 | +
852 | +           # Add comprehensive CV results to metadata
853 | +           if cv_results and 'test_scores' in cv_results:
854 | +               metadata['cross_validation'] = {
855 | +                   'n_splits': cv_results.get('n_splits', self.cv_folds),
856 | +                   'test_scores': cv_results['test_scores'],
857 | +                   'train_scores': cv_results.get('train_scores', {}),
858 | +                   'overfitting_score': cv_results.get('overfitting_score', 'Unknown'),
859 | +                   'stability_score': cv_results.get('stability_score', 'Unknown'),
860 | +                   'individual_fold_results': cv_results.get('fold_results', [])
861 | +               }
862 | +
863 | +               # Add summary statistics
864 | +               if 'f1' in cv_results['test_scores']:
865 | +                   metadata['cv_f1_mean'] = cv_results['test_scores']['f1']['mean']
866 | +                   metadata['cv_f1_std'] = cv_results['test_scores']['f1']['std']
867 | +                   metadata['cv_f1_min'] = cv_results['test_scores']['f1']['min']
868 | +                   metadata['cv_f1_max'] = cv_results['test_scores']['f1']['max']
869 | +
870 | +               if 'accuracy' in cv_results['test_scores']:
871 | +                   metadata['cv_accuracy_mean'] = cv_results['test_scores']['accuracy']['mean']
872 | +                   metadata['cv_accuracy_std'] = cv_results['test_scores']['accuracy']['std']
873 | +
874 | +           # Add model comparison results if available
875 | +           if len(results) > 1:
876 | +               model_comparison = {}
877 | +               for other_model_name, other_result in results.items():
878 | +                   if other_model_name != model_name and 'error' not in other_result:
879 | +                       other_cv = other_result['evaluation_metrics'].get('cross_validation', {})
880 | +                       if cv_results and other_cv:
881 | +                           comparison = self.cv_manager.compare_cv_results(cv_results, other_cv)
882 | +                           model_comparison[other_model_name] = comparison
883 | +
884 | +               if model_comparison:
885 | +                   metadata['model_comparison'] = model_comparison
886 |
887 |             # Save metadata with error handling
888 |             try:
889 |                 with open(self.metadata_path, 'w') as f:
890 |                     json.dump(metadata, f, indent=2)
891 | +               logger.info(f"✅ Saved enhanced metadata to {self.metadata_path}")
892 |             except Exception as e:
893 |                 logger.warning(f"Could not save metadata: {e}")
894 |
895 | +           logger.info("✅ Model artifacts saved successfully with CV results")
896 |             return True
897 |
898 |         except Exception as e:
907 |             return False
908 |
909 |     def train_model(self, data_path: str = None) -> Tuple[bool, str]:
910 | +       """Main training function with comprehensive CV pipeline"""
911 |         try:
912 | +           logger.info("Starting enhanced model training with cross-validation...")
913 |
914 |             # Override data path if provided
915 |             if data_path:

927 |                 cv_folds=self.cv_folds
928 |             )
929 |
930 | +           print(f"\n📊 Enhanced Training Configuration:")
931 |             print(f"Dataset size: {len(df)} samples")
932 | +           print(f"Cross-validation folds: {self.cv_folds}")
933 |             print(f"Estimated time: {time_estimate['total_formatted']}")
934 |             print(f"Models to train: {len(self.models)}")
935 | +           print(f"Hyperparameter tuning: Enabled")
936 |             print()
937 |
938 | +           # Setup progress tracker (increased steps for CV)
939 | +           total_steps = 4 + (len(self.models) * 3) + 1  # Load, split, 3 per model (tune + CV + eval), select, save
940 | +           self.progress_tracker = ProgressTracker(total_steps, "CV Training Progress")
941 |
942 |             # Prepare data
943 |             X = df['text'].values

968 |
969 |             # Additional validation for very small datasets
970 |             if len(X_train) < 3:
971 | +               logger.warning(f"Very small training set: {len(X_train)} samples. CV results may be unreliable.")
972 |             if len(X_test) < 1:
973 |                 return False, "Cannot create test set. Dataset too small."
974 |
975 | +           # Train and evaluate models with CV
976 | +           results = self.train_and_evaluate_models(X_train, X_test, y_train, y_test)
977 |
978 |             # Select best model
979 |             best_model_name, best_model, best_metrics = self.select_best_model(results)
980 |
981 | +           # Save model artifacts with CV results
982 | +           if not self.save_model_artifacts(best_model, best_model_name, best_metrics, results):
983 |                 return False, "Failed to save model artifacts"
984 |
985 |             # Finish progress tracking
986 |             self.progress_tracker.finish()
987 |
988 | +           # Create success message with CV information
989 | +           cv_results = best_metrics.get('cross_validation', {})
990 | +           cv_info = ""
991 | +           if 'test_scores' in cv_results and 'f1' in cv_results['test_scores']:
992 | +               cv_f1_mean = cv_results['test_scores']['f1']['mean']
993 | +               cv_f1_std = cv_results['test_scores']['f1']['std']
994 | +               cv_info = f", CV F1: {cv_f1_mean:.4f} (±{cv_f1_std:.4f})"
995 | +
996 |             success_message = (
997 | +               f"Enhanced model training completed successfully. "
998 |                 f"Best model: {best_model_name} "
999 | +               f"(Test F1: {best_metrics['f1']:.4f}, Test Accuracy: {best_metrics['accuracy']:.4f}{cv_info})"
1000 |            )
1001 |
1002 |            logger.info(success_message)

1005 |        except Exception as e:
1006 |            if self.progress_tracker:
1007 |                print()  # New line after progress bar
1008 | +          error_message = f"Enhanced model training failed: {str(e)}"
1009 |            logger.error(error_message)
1010 |            return False, error_message
1011 |
1012 |
1013 | def main():
1014 | +   """Main execution function with enhanced CV support"""
1015 |     import argparse
1016 |
1017 |     # Parse command line arguments
1018 | +   parser = argparse.ArgumentParser(description='Train fake news detection model with cross-validation')
1019 |     parser.add_argument('--data_path', type=str, help='Path to training data CSV file')
1020 |     parser.add_argument('--config_path', type=str, help='Path to training configuration JSON file')
1021 | +   parser.add_argument('--cv_folds', type=int, default=5, help='Number of cross-validation folds')
1022 |     args = parser.parse_args()
1023 |
1024 |     trainer = RobustModelTrainer()
1025 |
1026 | +   # Apply CV folds from command line
1027 | +   if args.cv_folds:
1028 | +       trainer.cv_folds = args.cv_folds
1029 | +       trainer.cv_manager.cv_folds = args.cv_folds
1030 | +
1031 |     # Load custom configuration if provided
1032 |     if args.config_path and Path(args.config_path).exists():
1033 |         try:

1037 |             # Apply configuration
1038 |             trainer.test_size = config.get('test_size', trainer.test_size)
1039 |             trainer.cv_folds = config.get('cv_folds', trainer.cv_folds)
1040 | +           trainer.cv_manager.cv_folds = trainer.cv_folds
1041 |             trainer.max_features = config.get('max_features', trainer.max_features)
1042 |             trainer.ngram_range = tuple(config.get('ngram_range', trainer.ngram_range))
1043 |

1050 |             # Update feature selection based on max_features
1051 |             trainer.feature_selection_k = min(trainer.feature_selection_k, trainer.max_features)
1052 |
1053 | +           logger.info(f"Applied custom configuration with {trainer.cv_folds} CV folds: {config}")
1054 |
1055 |         except Exception as e:
1056 |             logger.warning(f"Failed to load configuration: {e}, using defaults")
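Once a run completes, the enhanced metadata written by save_model_artifacts() can be read back. A sketch; the JSON filename is a guess, since setup_paths() is not shown in this diff, so substitute whatever self.metadata_path resolves to.

    import json

    with open("model/metadata.json") as f:
        meta = json.load(f)

    print(meta["model_type"], meta.get("cv_f1_mean"), meta.get("cv_f1_std"))
    print(meta.get("cross_validation", {}).get("n_splits"))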