Commit
·
310a651
1
Parent(s):
9d61526
Update model/train.py
Browse filesCritical Issues in Original train.py:
Single evaluation metric (accuracy only)
No hyperparameter tuning
No cross-validation
Limited preprocessing
No feature selection
No model comparison
No comprehensive evaluation
No overfitting detection
Observational Fix:
Added comprehensive evaluation with multiple metrics
Added hyperparameter tuning with GridSearchCV
Added cross-validation for robust evaluation
Added advanced text preprocessing pipeline
Added feature selection and optimization
Added model comparison (Logistic Regression vs Random Forest)
Added comprehensive evaluation with confusion matrix, ROC curves
Added overfitting detection and model validation
- model/train.py +562 -76
model/train.py
CHANGED
|
@@ -1,87 +1,573 @@
|
|
| 1 |
import pandas as pd
|
|
|
|
| 2 |
from pathlib import Path
|
| 3 |
-
|
| 4 |
-
from sklearn.linear_model import LogisticRegression
|
| 5 |
-
from sklearn.metrics import accuracy_score
|
| 6 |
-
from sklearn.model_selection import train_test_split
|
| 7 |
-
import joblib
|
| 8 |
import json
|
| 9 |
-
import
|
| 10 |
import hashlib
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
-
#
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
|
|
|
|
|
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
def main():
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
# Vectorize
|
| 45 |
-
vectorizer = TfidfVectorizer(stop_words='english', max_features=5000)
|
| 46 |
-
X_train_vec = vectorizer.fit_transform(X_train)
|
| 47 |
-
X_test_vec = vectorizer.transform(X_test)
|
| 48 |
-
|
| 49 |
-
# print('Train/Test Splits Created')
|
| 50 |
-
# print('Starting Model Training')
|
| 51 |
-
|
| 52 |
-
# Train model
|
| 53 |
-
model = LogisticRegression(max_iter=1000)
|
| 54 |
-
model.fit(X_train_vec, y_train)
|
| 55 |
-
|
| 56 |
-
# print('Model Training Completed')
|
| 57 |
-
#print('Model Evaluation Starting!')
|
| 58 |
-
|
| 59 |
-
# Evaluate
|
| 60 |
-
y_pred = model.predict(X_test_vec)
|
| 61 |
-
acc = accuracy_score(y_test, y_pred)
|
| 62 |
-
|
| 63 |
-
# Save model + vectorizer
|
| 64 |
-
joblib.dump(model, MODEL_PATH)
|
| 65 |
-
joblib.dump(vectorizer, VECTORIZER_PATH)
|
| 66 |
-
|
| 67 |
-
# print('Model Evaluation Done')
|
| 68 |
-
# print('Model Saved!')
|
| 69 |
-
|
| 70 |
-
# Save metadata
|
| 71 |
-
metadata = {
|
| 72 |
-
"model_version": f"v1.0",
|
| 73 |
-
"data_version": hash_file(DATA_PATH),
|
| 74 |
-
"train_size": len(X_train),
|
| 75 |
-
"test_size": len(X_test),
|
| 76 |
-
"test_accuracy": round(acc, 4),
|
| 77 |
-
"timestamp": datetime.datetime.now().isoformat()
|
| 78 |
-
}
|
| 79 |
-
with open(METADATA_PATH, 'w') as f:
|
| 80 |
-
json.dump(metadata, f, indent=4)
|
| 81 |
-
|
| 82 |
-
print(f"✅ Model trained and saved.")
|
| 83 |
-
print(f"📊 Test Accuracy: {acc:.4f}")
|
| 84 |
-
print(f"📝 Metadata saved to {METADATA_PATH}")
|
| 85 |
|
| 86 |
if __name__ == "__main__":
|
| 87 |
-
main()
|
|
|
|
| 1 |
import pandas as pd
|
| 2 |
+
import numpy as np
|
| 3 |
from pathlib import Path
|
| 4 |
+
import logging
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
import json
|
| 6 |
+
import joblib
|
| 7 |
import hashlib
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
from typing import Dict, Tuple, Optional, Any
|
| 10 |
+
import warnings
|
| 11 |
+
warnings.filterwarnings('ignore')
|
| 12 |
|
| 13 |
+
# Scikit-learn imports
|
| 14 |
+
from sklearn.feature_extraction.text import TfidfVectorizer
|
| 15 |
+
from sklearn.linear_model import LogisticRegression
|
| 16 |
+
from sklearn.ensemble import RandomForestClassifier
|
| 17 |
+
from sklearn.model_selection import (
|
| 18 |
+
train_test_split, cross_val_score, GridSearchCV,
|
| 19 |
+
StratifiedKFold, validation_curve
|
| 20 |
+
)
|
| 21 |
+
from sklearn.metrics import (
|
| 22 |
+
accuracy_score, precision_score, recall_score, f1_score,
|
| 23 |
+
roc_auc_score, confusion_matrix, classification_report,
|
| 24 |
+
precision_recall_curve, roc_curve
|
| 25 |
+
)
|
| 26 |
+
from sklearn.pipeline import Pipeline
|
| 27 |
+
from sklearn.preprocessing import FunctionTransformer
|
| 28 |
+
from sklearn.feature_selection import SelectKBest, chi2
|
| 29 |
+
import matplotlib.pyplot as plt
|
| 30 |
+
import seaborn as sns
|
| 31 |
|
| 32 |
+
# Configure logging
|
| 33 |
+
logging.basicConfig(
|
| 34 |
+
level=logging.INFO,
|
| 35 |
+
format='%(asctime)s - %(levelname)s - %(message)s',
|
| 36 |
+
handlers=[
|
| 37 |
+
logging.FileHandler('/tmp/model_training.log'),
|
| 38 |
+
logging.StreamHandler()
|
| 39 |
+
]
|
| 40 |
+
)
|
| 41 |
+
logger = logging.getLogger(__name__)
|
| 42 |
|
| 43 |
+
class RobustModelTrainer:
|
| 44 |
+
"""Production-ready model trainer with comprehensive evaluation and validation"""
|
| 45 |
+
|
| 46 |
+
def __init__(self):
|
| 47 |
+
self.setup_paths()
|
| 48 |
+
self.setup_training_config()
|
| 49 |
+
self.setup_models()
|
| 50 |
+
|
| 51 |
+
def setup_paths(self):
|
| 52 |
+
"""Setup all necessary paths"""
|
| 53 |
+
self.base_dir = Path("/tmp")
|
| 54 |
+
self.data_dir = self.base_dir / "data"
|
| 55 |
+
self.model_dir = self.base_dir / "model"
|
| 56 |
+
self.results_dir = self.base_dir / "results"
|
| 57 |
+
|
| 58 |
+
# Create directories
|
| 59 |
+
for dir_path in [self.data_dir, self.model_dir, self.results_dir]:
|
| 60 |
+
dir_path.mkdir(parents=True, exist_ok=True)
|
| 61 |
+
|
| 62 |
+
# File paths
|
| 63 |
+
self.data_path = self.data_dir / "combined_dataset.csv"
|
| 64 |
+
self.model_path = self.model_dir / "model.pkl"
|
| 65 |
+
self.vectorizer_path = self.model_dir / "vectorizer.pkl"
|
| 66 |
+
self.pipeline_path = self.model_dir / "pipeline.pkl"
|
| 67 |
+
self.metadata_path = Path("/tmp/metadata.json")
|
| 68 |
+
self.evaluation_path = self.results_dir / "evaluation_results.json"
|
| 69 |
+
|
| 70 |
+
def setup_training_config(self):
|
| 71 |
+
"""Setup training configuration"""
|
| 72 |
+
self.test_size = 0.2
|
| 73 |
+
self.validation_size = 0.1
|
| 74 |
+
self.random_state = 42
|
| 75 |
+
self.cv_folds = 5
|
| 76 |
+
self.max_features = 10000
|
| 77 |
+
self.min_df = 2
|
| 78 |
+
self.max_df = 0.95
|
| 79 |
+
self.ngram_range = (1, 3)
|
| 80 |
+
self.max_iter = 1000
|
| 81 |
+
self.class_weight = 'balanced'
|
| 82 |
+
self.feature_selection_k = 5000
|
| 83 |
+
|
| 84 |
+
def setup_models(self):
|
| 85 |
+
"""Setup model configurations for comparison"""
|
| 86 |
+
self.models = {
|
| 87 |
+
'logistic_regression': {
|
| 88 |
+
'model': LogisticRegression(
|
| 89 |
+
max_iter=self.max_iter,
|
| 90 |
+
class_weight=self.class_weight,
|
| 91 |
+
random_state=self.random_state
|
| 92 |
+
),
|
| 93 |
+
'param_grid': {
|
| 94 |
+
'model__C': [0.1, 1, 10, 100],
|
| 95 |
+
'model__penalty': ['l2'],
|
| 96 |
+
'model__solver': ['liblinear', 'lbfgs']
|
| 97 |
+
}
|
| 98 |
+
},
|
| 99 |
+
'random_forest': {
|
| 100 |
+
'model': RandomForestClassifier(
|
| 101 |
+
n_estimators=100,
|
| 102 |
+
class_weight=self.class_weight,
|
| 103 |
+
random_state=self.random_state
|
| 104 |
+
),
|
| 105 |
+
'param_grid': {
|
| 106 |
+
'model__n_estimators': [50, 100, 200],
|
| 107 |
+
'model__max_depth': [10, 20, None],
|
| 108 |
+
'model__min_samples_split': [2, 5, 10]
|
| 109 |
+
}
|
| 110 |
+
}
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
def load_and_validate_data(self) -> Tuple[bool, Optional[pd.DataFrame], str]:
|
| 114 |
+
"""Load and validate training data"""
|
| 115 |
+
try:
|
| 116 |
+
logger.info("Loading training data...")
|
| 117 |
+
|
| 118 |
+
if not self.data_path.exists():
|
| 119 |
+
return False, None, f"Data file not found: {self.data_path}"
|
| 120 |
+
|
| 121 |
+
# Load data
|
| 122 |
+
df = pd.read_csv(self.data_path)
|
| 123 |
+
|
| 124 |
+
# Basic validation
|
| 125 |
+
if df.empty:
|
| 126 |
+
return False, None, "Dataset is empty"
|
| 127 |
+
|
| 128 |
+
required_columns = ['text', 'label']
|
| 129 |
+
missing_columns = [col for col in required_columns if col not in df.columns]
|
| 130 |
+
if missing_columns:
|
| 131 |
+
return False, None, f"Missing required columns: {missing_columns}"
|
| 132 |
+
|
| 133 |
+
# Remove missing values
|
| 134 |
+
initial_count = len(df)
|
| 135 |
+
df = df.dropna(subset=required_columns)
|
| 136 |
+
if len(df) < initial_count:
|
| 137 |
+
logger.warning(f"Removed {initial_count - len(df)} rows with missing values")
|
| 138 |
+
|
| 139 |
+
# Validate text content
|
| 140 |
+
df = df[df['text'].astype(str).str.len() > 10]
|
| 141 |
+
|
| 142 |
+
# Validate labels
|
| 143 |
+
unique_labels = df['label'].unique()
|
| 144 |
+
if len(unique_labels) < 2:
|
| 145 |
+
return False, None, f"Need at least 2 classes, found: {unique_labels}"
|
| 146 |
+
|
| 147 |
+
# Check minimum sample size
|
| 148 |
+
if len(df) < 100:
|
| 149 |
+
return False, None, f"Insufficient samples for training: {len(df)}"
|
| 150 |
+
|
| 151 |
+
# Check class balance
|
| 152 |
+
label_counts = df['label'].value_counts()
|
| 153 |
+
min_class_ratio = label_counts.min() / label_counts.max()
|
| 154 |
+
if min_class_ratio < 0.1:
|
| 155 |
+
logger.warning(f"Severe class imbalance detected: {min_class_ratio:.3f}")
|
| 156 |
+
|
| 157 |
+
logger.info(f"Data validation successful: {len(df)} samples, {len(unique_labels)} classes")
|
| 158 |
+
logger.info(f"Class distribution: {label_counts.to_dict()}")
|
| 159 |
+
|
| 160 |
+
return True, df, "Data loaded successfully"
|
| 161 |
+
|
| 162 |
+
except Exception as e:
|
| 163 |
+
error_msg = f"Error loading data: {str(e)}"
|
| 164 |
+
logger.error(error_msg)
|
| 165 |
+
return False, None, error_msg
|
| 166 |
+
|
| 167 |
+
def preprocess_text(self, text):
|
| 168 |
+
"""Advanced text preprocessing"""
|
| 169 |
+
import re
|
| 170 |
+
|
| 171 |
+
# Convert to string
|
| 172 |
+
text = str(text)
|
| 173 |
+
|
| 174 |
+
# Remove URLs
|
| 175 |
+
text = re.sub(r'http\S+|www\S+|https\S+', '', text)
|
| 176 |
+
|
| 177 |
+
# Remove email addresses
|
| 178 |
+
text = re.sub(r'\S+@\S+', '', text)
|
| 179 |
+
|
| 180 |
+
# Remove excessive punctuation
|
| 181 |
+
text = re.sub(r'[!]{2,}', '!', text)
|
| 182 |
+
text = re.sub(r'[?]{2,}', '?', text)
|
| 183 |
+
text = re.sub(r'[.]{3,}', '...', text)
|
| 184 |
+
|
| 185 |
+
# Remove non-alphabetic characters except spaces and basic punctuation
|
| 186 |
+
text = re.sub(r'[^a-zA-Z\s.!?]', '', text)
|
| 187 |
+
|
| 188 |
+
# Remove excessive whitespace
|
| 189 |
+
text = re.sub(r'\s+', ' ', text)
|
| 190 |
+
|
| 191 |
+
return text.strip().lower()
|
| 192 |
+
|
| 193 |
+
def create_preprocessing_pipeline(self) -> Pipeline:
|
| 194 |
+
"""Create advanced preprocessing pipeline"""
|
| 195 |
+
# Text preprocessing
|
| 196 |
+
text_preprocessor = FunctionTransformer(
|
| 197 |
+
func=lambda x: [self.preprocess_text(text) for text in x],
|
| 198 |
+
validate=False
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
# TF-IDF vectorization
|
| 202 |
+
vectorizer = TfidfVectorizer(
|
| 203 |
+
max_features=self.max_features,
|
| 204 |
+
min_df=self.min_df,
|
| 205 |
+
max_df=self.max_df,
|
| 206 |
+
ngram_range=self.ngram_range,
|
| 207 |
+
stop_words='english',
|
| 208 |
+
sublinear_tf=True,
|
| 209 |
+
norm='l2'
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
# Feature selection
|
| 213 |
+
feature_selector = SelectKBest(
|
| 214 |
+
score_func=chi2,
|
| 215 |
+
k=self.feature_selection_k
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
# Create pipeline
|
| 219 |
+
pipeline = Pipeline([
|
| 220 |
+
('preprocess', text_preprocessor),
|
| 221 |
+
('vectorize', vectorizer),
|
| 222 |
+
('feature_select', feature_selector),
|
| 223 |
+
('model', None) # Will be set during training
|
| 224 |
+
])
|
| 225 |
+
|
| 226 |
+
return pipeline
|
| 227 |
+
|
| 228 |
+
def comprehensive_evaluation(self, model, X_test, y_test, X_train=None, y_train=None) -> Dict:
|
| 229 |
+
"""Comprehensive model evaluation with multiple metrics"""
|
| 230 |
+
logger.info("Starting comprehensive model evaluation...")
|
| 231 |
+
|
| 232 |
+
# Predictions
|
| 233 |
+
y_pred = model.predict(X_test)
|
| 234 |
+
y_pred_proba = model.predict_proba(X_test)[:, 1]
|
| 235 |
+
|
| 236 |
+
# Basic metrics
|
| 237 |
+
metrics = {
|
| 238 |
+
'accuracy': float(accuracy_score(y_test, y_pred)),
|
| 239 |
+
'precision': float(precision_score(y_test, y_pred, average='weighted')),
|
| 240 |
+
'recall': float(recall_score(y_test, y_pred, average='weighted')),
|
| 241 |
+
'f1': float(f1_score(y_test, y_pred, average='weighted')),
|
| 242 |
+
'roc_auc': float(roc_auc_score(y_test, y_pred_proba))
|
| 243 |
+
}
|
| 244 |
+
|
| 245 |
+
# Confusion matrix
|
| 246 |
+
cm = confusion_matrix(y_test, y_pred)
|
| 247 |
+
metrics['confusion_matrix'] = cm.tolist()
|
| 248 |
+
|
| 249 |
+
# Classification report
|
| 250 |
+
class_report = classification_report(y_test, y_pred, output_dict=True)
|
| 251 |
+
metrics['classification_report'] = class_report
|
| 252 |
+
|
| 253 |
+
# Cross-validation scores if training data provided
|
| 254 |
+
if X_train is not None and y_train is not None:
|
| 255 |
+
try:
|
| 256 |
+
cv_scores = cross_val_score(
|
| 257 |
+
model, X_train, y_train,
|
| 258 |
+
cv=StratifiedKFold(n_splits=self.cv_folds, shuffle=True, random_state=self.random_state),
|
| 259 |
+
scoring='f1_weighted'
|
| 260 |
+
)
|
| 261 |
+
metrics['cv_scores'] = {
|
| 262 |
+
'mean': float(cv_scores.mean()),
|
| 263 |
+
'std': float(cv_scores.std()),
|
| 264 |
+
'scores': cv_scores.tolist()
|
| 265 |
+
}
|
| 266 |
+
except Exception as e:
|
| 267 |
+
logger.warning(f"Cross-validation failed: {e}")
|
| 268 |
+
metrics['cv_scores'] = None
|
| 269 |
+
|
| 270 |
+
# Feature importance (if available)
|
| 271 |
+
try:
|
| 272 |
+
if hasattr(model, 'feature_importances_'):
|
| 273 |
+
feature_importance = model.feature_importances_
|
| 274 |
+
metrics['feature_importance_stats'] = {
|
| 275 |
+
'mean': float(feature_importance.mean()),
|
| 276 |
+
'std': float(feature_importance.std()),
|
| 277 |
+
'top_features': feature_importance.argsort()[-10:][::-1].tolist()
|
| 278 |
+
}
|
| 279 |
+
elif hasattr(model, 'coef_'):
|
| 280 |
+
coefficients = model.coef_[0]
|
| 281 |
+
metrics['coefficient_stats'] = {
|
| 282 |
+
'mean': float(coefficients.mean()),
|
| 283 |
+
'std': float(coefficients.std()),
|
| 284 |
+
'top_positive': coefficients.argsort()[-10:][::-1].tolist(),
|
| 285 |
+
'top_negative': coefficients.argsort()[:10].tolist()
|
| 286 |
+
}
|
| 287 |
+
except Exception as e:
|
| 288 |
+
logger.warning(f"Feature importance extraction failed: {e}")
|
| 289 |
+
|
| 290 |
+
# Model complexity metrics
|
| 291 |
+
try:
|
| 292 |
+
# Training accuracy for overfitting detection
|
| 293 |
+
if X_train is not None and y_train is not None:
|
| 294 |
+
y_train_pred = model.predict(X_train)
|
| 295 |
+
train_accuracy = accuracy_score(y_train, y_train_pred)
|
| 296 |
+
metrics['train_accuracy'] = float(train_accuracy)
|
| 297 |
+
metrics['overfitting_score'] = float(train_accuracy - metrics['accuracy'])
|
| 298 |
+
except Exception as e:
|
| 299 |
+
logger.warning(f"Overfitting detection failed: {e}")
|
| 300 |
+
|
| 301 |
+
return metrics
|
| 302 |
+
|
| 303 |
+
def hyperparameter_tuning(self, pipeline, X_train, y_train, model_name: str) -> Tuple[Any, Dict]:
|
| 304 |
+
"""Perform hyperparameter tuning with cross-validation"""
|
| 305 |
+
logger.info(f"Starting hyperparameter tuning for {model_name}...")
|
| 306 |
+
|
| 307 |
+
try:
|
| 308 |
+
# Set the model in the pipeline
|
| 309 |
+
pipeline.set_params(model=self.models[model_name]['model'])
|
| 310 |
+
|
| 311 |
+
# Get parameter grid
|
| 312 |
+
param_grid = self.models[model_name]['param_grid']
|
| 313 |
+
|
| 314 |
+
# Create GridSearchCV
|
| 315 |
+
grid_search = GridSearchCV(
|
| 316 |
+
pipeline,
|
| 317 |
+
param_grid,
|
| 318 |
+
cv=StratifiedKFold(n_splits=self.cv_folds, shuffle=True, random_state=self.random_state),
|
| 319 |
+
scoring='f1_weighted',
|
| 320 |
+
n_jobs=-1,
|
| 321 |
+
verbose=1
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
# Fit grid search
|
| 325 |
+
grid_search.fit(X_train, y_train)
|
| 326 |
+
|
| 327 |
+
# Extract results
|
| 328 |
+
tuning_results = {
|
| 329 |
+
'best_params': grid_search.best_params_,
|
| 330 |
+
'best_score': float(grid_search.best_score_),
|
| 331 |
+
'best_estimator': grid_search.best_estimator_,
|
| 332 |
+
'cv_results': {
|
| 333 |
+
'mean_test_scores': grid_search.cv_results_['mean_test_score'].tolist(),
|
| 334 |
+
'std_test_scores': grid_search.cv_results_['std_test_score'].tolist(),
|
| 335 |
+
'params': grid_search.cv_results_['params']
|
| 336 |
+
}
|
| 337 |
+
}
|
| 338 |
+
|
| 339 |
+
logger.info(f"Hyperparameter tuning completed for {model_name}")
|
| 340 |
+
logger.info(f"Best score: {grid_search.best_score_:.4f}")
|
| 341 |
+
logger.info(f"Best params: {grid_search.best_params_}")
|
| 342 |
+
|
| 343 |
+
return grid_search.best_estimator_, tuning_results
|
| 344 |
+
|
| 345 |
+
except Exception as e:
|
| 346 |
+
logger.error(f"Hyperparameter tuning failed for {model_name}: {str(e)}")
|
| 347 |
+
# Return basic model if tuning fails
|
| 348 |
+
pipeline.set_params(model=self.models[model_name]['model'])
|
| 349 |
+
pipeline.fit(X_train, y_train)
|
| 350 |
+
return pipeline, {'error': str(e)}
|
| 351 |
+
|
| 352 |
+
def train_and_evaluate_models(self, X_train, X_test, y_train, y_test) -> Dict:
|
| 353 |
+
"""Train and evaluate multiple models"""
|
| 354 |
+
logger.info("Starting model training and evaluation...")
|
| 355 |
+
|
| 356 |
+
results = {}
|
| 357 |
+
|
| 358 |
+
for model_name in self.models.keys():
|
| 359 |
+
logger.info(f"Training {model_name}...")
|
| 360 |
+
|
| 361 |
+
try:
|
| 362 |
+
# Create pipeline
|
| 363 |
+
pipeline = self.create_preprocessing_pipeline()
|
| 364 |
+
|
| 365 |
+
# Hyperparameter tuning
|
| 366 |
+
best_model, tuning_results = self.hyperparameter_tuning(
|
| 367 |
+
pipeline, X_train, y_train, model_name
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
# Comprehensive evaluation
|
| 371 |
+
evaluation_metrics = self.comprehensive_evaluation(
|
| 372 |
+
best_model, X_test, y_test, X_train, y_train
|
| 373 |
+
)
|
| 374 |
+
|
| 375 |
+
# Store results
|
| 376 |
+
results[model_name] = {
|
| 377 |
+
'model': best_model,
|
| 378 |
+
'tuning_results': tuning_results,
|
| 379 |
+
'evaluation_metrics': evaluation_metrics,
|
| 380 |
+
'training_time': datetime.now().isoformat()
|
| 381 |
+
}
|
| 382 |
+
|
| 383 |
+
logger.info(f"Model {model_name} - F1: {evaluation_metrics['f1']:.4f}, "
|
| 384 |
+
f"Accuracy: {evaluation_metrics['accuracy']:.4f}")
|
| 385 |
+
|
| 386 |
+
except Exception as e:
|
| 387 |
+
logger.error(f"Training failed for {model_name}: {str(e)}")
|
| 388 |
+
results[model_name] = {'error': str(e)}
|
| 389 |
+
|
| 390 |
+
return results
|
| 391 |
+
|
| 392 |
+
def select_best_model(self, results: Dict) -> Tuple[str, Any, Dict]:
|
| 393 |
+
"""Select the best performing model"""
|
| 394 |
+
logger.info("Selecting best model...")
|
| 395 |
+
|
| 396 |
+
best_model_name = None
|
| 397 |
+
best_model = None
|
| 398 |
+
best_score = -1
|
| 399 |
+
best_metrics = None
|
| 400 |
+
|
| 401 |
+
for model_name, result in results.items():
|
| 402 |
+
if 'error' in result:
|
| 403 |
+
continue
|
| 404 |
+
|
| 405 |
+
# Use F1 score as primary metric
|
| 406 |
+
f1_score = result['evaluation_metrics']['f1']
|
| 407 |
+
|
| 408 |
+
if f1_score > best_score:
|
| 409 |
+
best_score = f1_score
|
| 410 |
+
best_model_name = model_name
|
| 411 |
+
best_model = result['model']
|
| 412 |
+
best_metrics = result['evaluation_metrics']
|
| 413 |
+
|
| 414 |
+
if best_model_name is None:
|
| 415 |
+
raise ValueError("No models trained successfully")
|
| 416 |
+
|
| 417 |
+
logger.info(f"Best model: {best_model_name} with F1 score: {best_score:.4f}")
|
| 418 |
+
return best_model_name, best_model, best_metrics
|
| 419 |
+
|
| 420 |
+
def save_model_artifacts(self, model, model_name: str, metrics: Dict) -> bool:
|
| 421 |
+
"""Save model artifacts and metadata"""
|
| 422 |
+
try:
|
| 423 |
+
logger.info("Saving model artifacts...")
|
| 424 |
+
|
| 425 |
+
# Save the full pipeline
|
| 426 |
+
joblib.dump(model, self.pipeline_path)
|
| 427 |
+
|
| 428 |
+
# Save individual components for backward compatibility
|
| 429 |
+
joblib.dump(model.named_steps['model'], self.model_path)
|
| 430 |
+
joblib.dump(model.named_steps['vectorize'], self.vectorizer_path)
|
| 431 |
+
|
| 432 |
+
# Generate data hash
|
| 433 |
+
data_hash = hashlib.md5(str(datetime.now()).encode()).hexdigest()
|
| 434 |
+
|
| 435 |
+
# Create metadata
|
| 436 |
+
metadata = {
|
| 437 |
+
'model_version': f"v1.0_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
|
| 438 |
+
'model_type': model_name,
|
| 439 |
+
'data_version': data_hash,
|
| 440 |
+
'train_size': metrics.get('train_accuracy', 'Unknown'),
|
| 441 |
+
'test_size': len(metrics.get('confusion_matrix', [[0]])[0]) if 'confusion_matrix' in metrics else 'Unknown',
|
| 442 |
+
'test_accuracy': metrics['accuracy'],
|
| 443 |
+
'test_f1': metrics['f1'],
|
| 444 |
+
'test_precision': metrics['precision'],
|
| 445 |
+
'test_recall': metrics['recall'],
|
| 446 |
+
'test_roc_auc': metrics['roc_auc'],
|
| 447 |
+
'overfitting_score': metrics.get('overfitting_score', 'Unknown'),
|
| 448 |
+
'cv_score_mean': metrics.get('cv_scores', {}).get('mean', 'Unknown'),
|
| 449 |
+
'cv_score_std': metrics.get('cv_scores', {}).get('std', 'Unknown'),
|
| 450 |
+
'timestamp': datetime.now().isoformat(),
|
| 451 |
+
'training_config': {
|
| 452 |
+
'test_size': self.test_size,
|
| 453 |
+
'validation_size': self.validation_size,
|
| 454 |
+
'cv_folds': self.cv_folds,
|
| 455 |
+
'max_features': self.max_features,
|
| 456 |
+
'ngram_range': self.ngram_range,
|
| 457 |
+
'feature_selection_k': self.feature_selection_k
|
| 458 |
+
}
|
| 459 |
+
}
|
| 460 |
+
|
| 461 |
+
# Save metadata
|
| 462 |
+
with open(self.metadata_path, 'w') as f:
|
| 463 |
+
json.dump(metadata, f, indent=2)
|
| 464 |
+
|
| 465 |
+
logger.info(f"Model artifacts saved successfully")
|
| 466 |
+
logger.info(f"Model path: {self.model_path}")
|
| 467 |
+
logger.info(f"Vectorizer path: {self.vectorizer_path}")
|
| 468 |
+
logger.info(f"Pipeline path: {self.pipeline_path}")
|
| 469 |
+
logger.info(f"Metadata path: {self.metadata_path}")
|
| 470 |
+
|
| 471 |
+
return True
|
| 472 |
+
|
| 473 |
+
except Exception as e:
|
| 474 |
+
logger.error(f"Failed to save model artifacts: {str(e)}")
|
| 475 |
+
return False
|
| 476 |
+
|
| 477 |
+
def save_evaluation_results(self, results: Dict) -> bool:
|
| 478 |
+
"""Save comprehensive evaluation results"""
|
| 479 |
+
try:
|
| 480 |
+
# Clean results for JSON serialization
|
| 481 |
+
clean_results = {}
|
| 482 |
+
for model_name, result in results.items():
|
| 483 |
+
if 'error' in result:
|
| 484 |
+
clean_results[model_name] = result
|
| 485 |
+
else:
|
| 486 |
+
clean_results[model_name] = {
|
| 487 |
+
'tuning_results': {
|
| 488 |
+
k: v for k, v in result['tuning_results'].items()
|
| 489 |
+
if k != 'best_estimator'
|
| 490 |
+
},
|
| 491 |
+
'evaluation_metrics': result['evaluation_metrics'],
|
| 492 |
+
'training_time': result['training_time']
|
| 493 |
+
}
|
| 494 |
+
|
| 495 |
+
# Save results
|
| 496 |
+
with open(self.evaluation_path, 'w') as f:
|
| 497 |
+
json.dump(clean_results, f, indent=2, default=str)
|
| 498 |
+
|
| 499 |
+
logger.info(f"Evaluation results saved to {self.evaluation_path}")
|
| 500 |
+
return True
|
| 501 |
+
|
| 502 |
+
except Exception as e:
|
| 503 |
+
logger.error(f"Failed to save evaluation results: {str(e)}")
|
| 504 |
+
return False
|
| 505 |
+
|
| 506 |
+
def train_model(self, data_path: str = None) -> Tuple[bool, str]:
|
| 507 |
+
"""Main training function with comprehensive pipeline"""
|
| 508 |
+
try:
|
| 509 |
+
logger.info("Starting model training pipeline...")
|
| 510 |
+
|
| 511 |
+
# Override data path if provided
|
| 512 |
+
if data_path:
|
| 513 |
+
self.data_path = Path(data_path)
|
| 514 |
+
|
| 515 |
+
# Load and validate data
|
| 516 |
+
success, df, message = self.load_and_validate_data()
|
| 517 |
+
if not success:
|
| 518 |
+
return False, message
|
| 519 |
+
|
| 520 |
+
# Prepare data
|
| 521 |
+
X = df['text'].values
|
| 522 |
+
y = df['label'].values
|
| 523 |
+
|
| 524 |
+
# Train-test split
|
| 525 |
+
X_train, X_test, y_train, y_test = train_test_split(
|
| 526 |
+
X, y,
|
| 527 |
+
test_size=self.test_size,
|
| 528 |
+
stratify=y,
|
| 529 |
+
random_state=self.random_state
|
| 530 |
+
)
|
| 531 |
+
|
| 532 |
+
logger.info(f"Data split: {len(X_train)} train, {len(X_test)} test")
|
| 533 |
+
|
| 534 |
+
# Train and evaluate models
|
| 535 |
+
results = self.train_and_evaluate_models(X_train, X_test, y_train, y_test)
|
| 536 |
+
|
| 537 |
+
# Select best model
|
| 538 |
+
best_model_name, best_model, best_metrics = self.select_best_model(results)
|
| 539 |
+
|
| 540 |
+
# Save model artifacts
|
| 541 |
+
if not self.save_model_artifacts(best_model, best_model_name, best_metrics):
|
| 542 |
+
return False, "Failed to save model artifacts"
|
| 543 |
+
|
| 544 |
+
# Save evaluation results
|
| 545 |
+
self.save_evaluation_results(results)
|
| 546 |
+
|
| 547 |
+
success_message = (
|
| 548 |
+
f"Model training completed successfully. "
|
| 549 |
+
f"Best model: {best_model_name} "
|
| 550 |
+
f"(F1: {best_metrics['f1']:.4f}, Accuracy: {best_metrics['accuracy']:.4f})"
|
| 551 |
+
)
|
| 552 |
+
|
| 553 |
+
logger.info(success_message)
|
| 554 |
+
return True, success_message
|
| 555 |
+
|
| 556 |
+
except Exception as e:
|
| 557 |
+
error_message = f"Model training failed: {str(e)}"
|
| 558 |
+
logger.error(error_message)
|
| 559 |
+
return False, error_message
|
| 560 |
|
| 561 |
def main():
|
| 562 |
+
"""Main execution function"""
|
| 563 |
+
trainer = RobustModelTrainer()
|
| 564 |
+
success, message = trainer.train_model()
|
| 565 |
+
|
| 566 |
+
if success:
|
| 567 |
+
print(f"✅ {message}")
|
| 568 |
+
else:
|
| 569 |
+
print(f"❌ {message}")
|
| 570 |
+
exit(1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 571 |
|
| 572 |
if __name__ == "__main__":
|
| 573 |
+
main()
|