Commit
·
efab419
1
Parent(s):
310a651
Update model/retrain.py
Browse filesCritical Issues in Original retrain.py:
No statistical significance testing for model comparison
No backup and rollback mechanism
Simple accuracy comparison without comprehensive metrics
No A/B testing or gradual rollout
No validation of model promotion
No comprehensive logging and monitoring
No error handling for failed promotions
Observational Fix:
Added comprehensive statistical testing (McNemar's test)
Added backup and rollback mechanisms with version control
Added comprehensive model comparison with multiple metrics
Added practical significance testing with thresholds
Added model promotion validation and verification
Added comprehensive logging and session tracking
Added robust error handling and recovery mechanisms
- model/retrain.py +595 -117
model/retrain.py
CHANGED
|
@@ -1,130 +1,608 @@
|
|
| 1 |
import pandas as pd
|
| 2 |
-
|
| 3 |
-
from sklearn.feature_extraction.text import TfidfVectorizer
|
| 4 |
-
from sklearn.linear_model import LogisticRegression
|
| 5 |
-
from sklearn.model_selection import train_test_split
|
| 6 |
-
from sklearn.metrics import accuracy_score
|
| 7 |
import joblib
|
| 8 |
import json
|
| 9 |
-
import
|
| 10 |
-
import datetime
|
| 11 |
import shutil
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
-
#
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
# CANDIDATE_MODEL = BASE_DIR / "model_candidate.pkl"
|
| 26 |
-
# CANDIDATE_VECTORIZER = BASE_DIR / "vectorizer_candidate.pkl"
|
| 27 |
-
|
| 28 |
-
# METADATA_PATH = BASE_DIR / "metadata.json"
|
| 29 |
-
|
| 30 |
-
# Use /tmp as the writable directory in Docker/Hugging Face
|
| 31 |
-
BASE_DIR = Path("/tmp")
|
| 32 |
-
|
| 33 |
-
# Create writable subdirectories if they don’t exist
|
| 34 |
-
DATA_DIR = BASE_DIR / "data"
|
| 35 |
-
LOGS_DIR = BASE_DIR / "logs"
|
| 36 |
-
MODEL_DIR = BASE_DIR / "model"
|
| 37 |
-
|
| 38 |
-
DATA_DIR.mkdir(parents=True, exist_ok=True)
|
| 39 |
-
LOGS_DIR.mkdir(parents=True, exist_ok=True)
|
| 40 |
-
MODEL_DIR.mkdir(parents=True, exist_ok=True)
|
| 41 |
-
|
| 42 |
-
# File paths
|
| 43 |
-
COMBINED = DATA_DIR / "combined_dataset.csv"
|
| 44 |
-
SCRAPED = DATA_DIR / "scraped_real.csv"
|
| 45 |
-
GENERATED = DATA_DIR / "generated_fake.csv"
|
| 46 |
-
|
| 47 |
-
PROD_MODEL = MODEL_DIR / "model.pkl"
|
| 48 |
-
PROD_VECTORIZER = MODEL_DIR / "vectorizer.pkl"
|
| 49 |
-
|
| 50 |
-
CANDIDATE_MODEL = MODEL_DIR / "model_candidate.pkl"
|
| 51 |
-
CANDIDATE_VECTORIZER = MODEL_DIR / "vectorizer_candidate.pkl"
|
| 52 |
-
|
| 53 |
-
# METADATA_PATH = MODEL_DIR / "metadata.json"
|
| 54 |
-
METADATA_PATH = Path("/tmp/metadata.json")
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
def hash_file(path: Path):
|
| 58 |
-
return hashlib.md5(path.read_bytes()).hexdigest()
|
| 59 |
-
|
| 60 |
-
def load_new_data():
|
| 61 |
-
dfs = [pd.read_csv(COMBINED)]
|
| 62 |
-
if SCRAPED.exists():
|
| 63 |
-
dfs.append(pd.read_csv(SCRAPED))
|
| 64 |
-
if GENERATED.exists():
|
| 65 |
-
dfs.append(pd.read_csv(GENERATED))
|
| 66 |
-
df = pd.concat(dfs, ignore_index=True)
|
| 67 |
-
df.dropna(subset=["text"], inplace=True)
|
| 68 |
-
df = df[df["text"].str.strip() != ""]
|
| 69 |
-
return df
|
| 70 |
-
|
| 71 |
-
def train_model():
|
| 72 |
-
# Load the new data
|
| 73 |
-
df = load_new_data()
|
| 74 |
-
X = df["text"]
|
| 75 |
-
y = df["label"]
|
| 76 |
-
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, test_size=0.2, random_state=42)
|
| 77 |
-
|
| 78 |
-
vec = TfidfVectorizer(stop_words="english", max_features=5000)
|
| 79 |
-
X_train_vec = vec.fit_transform(X_train)
|
| 80 |
-
X_test_vec = vec.transform(X_test)
|
| 81 |
-
|
| 82 |
-
model = LogisticRegression(max_iter=1000)
|
| 83 |
-
model.fit(X_train_vec, y_train)
|
| 84 |
-
acc = accuracy_score(y_test, model.predict(X_test_vec))
|
| 85 |
-
|
| 86 |
-
return model, vec, acc, len(X_train), len(X_test)
|
| 87 |
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
|
| 98 |
def main():
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
joblib.dump(vec, CANDIDATE_VECTORIZER)
|
| 106 |
-
|
| 107 |
-
metadata = load_metadata()
|
| 108 |
-
prod_acc = metadata["test_accuracy"] if metadata else 0
|
| 109 |
-
model_version = bump_version(metadata["model_version"]) if metadata else "v1.0"
|
| 110 |
-
|
| 111 |
-
if acc > prod_acc:
|
| 112 |
-
print("✅ Candidate outperforms production. Promoting model...")
|
| 113 |
-
shutil.copy(CANDIDATE_MODEL, PROD_MODEL)
|
| 114 |
-
shutil.copy(CANDIDATE_VECTORIZER, PROD_VECTORIZER)
|
| 115 |
-
metadata = {
|
| 116 |
-
"model_version": model_version,
|
| 117 |
-
"data_version": hash_file(COMBINED),
|
| 118 |
-
"train_size": train_size,
|
| 119 |
-
"test_size": test_size,
|
| 120 |
-
"test_accuracy": round(acc, 4),
|
| 121 |
-
"timestamp": datetime.datetime.now().isoformat()
|
| 122 |
-
}
|
| 123 |
-
with open(METADATA_PATH, "w") as f:
|
| 124 |
-
json.dump(metadata, f, indent=2)
|
| 125 |
-
print(f"🟢 Model promoted. Version: {model_version}")
|
| 126 |
else:
|
| 127 |
-
print("
|
|
|
|
| 128 |
|
| 129 |
if __name__ == "__main__":
|
| 130 |
main()
|
|
|
|
| 1 |
import pandas as pd
|
| 2 |
+
import numpy as np
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
import joblib
|
| 4 |
import json
|
| 5 |
+
import logging
|
|
|
|
| 6 |
import shutil
|
| 7 |
+
import hashlib
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
from typing import Dict, Tuple, Optional, Any
|
| 11 |
+
from scipy import stats
|
| 12 |
+
import warnings
|
| 13 |
+
warnings.filterwarnings('ignore')
|
| 14 |
|
| 15 |
+
# Scikit-learn imports
|
| 16 |
+
from sklearn.metrics import (
|
| 17 |
+
accuracy_score, precision_score, recall_score, f1_score,
|
| 18 |
+
roc_auc_score, confusion_matrix, classification_report
|
| 19 |
+
)
|
| 20 |
+
from sklearn.model_selection import cross_val_score, StratifiedKFold
|
| 21 |
+
from sklearn.feature_extraction.text import TfidfVectorizer
|
| 22 |
+
from sklearn.linear_model import LogisticRegression
|
| 23 |
+
from sklearn.ensemble import RandomForestClassifier
|
| 24 |
+
from sklearn.pipeline import Pipeline
|
| 25 |
+
from sklearn.preprocessing import FunctionTransformer
|
| 26 |
+
from sklearn.feature_selection import SelectKBest, chi2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
+
# Configure logging
|
| 29 |
+
logging.basicConfig(
|
| 30 |
+
level=logging.INFO,
|
| 31 |
+
format='%(asctime)s - %(levelname)s - %(message)s',
|
| 32 |
+
handlers=[
|
| 33 |
+
logging.FileHandler('/tmp/model_retraining.log'),
|
| 34 |
+
logging.StreamHandler()
|
| 35 |
+
]
|
| 36 |
+
)
|
| 37 |
+
logger = logging.getLogger(__name__)
|
| 38 |
|
| 39 |
+
class RobustModelRetrainer:
|
| 40 |
+
"""Production-ready model retraining with statistical validation and A/B testing"""
|
| 41 |
+
|
| 42 |
+
def __init__(self):
|
| 43 |
+
self.setup_paths()
|
| 44 |
+
self.setup_retraining_config()
|
| 45 |
+
self.setup_statistical_tests()
|
| 46 |
+
|
| 47 |
+
def setup_paths(self):
|
| 48 |
+
"""Setup all necessary paths"""
|
| 49 |
+
self.base_dir = Path("/tmp")
|
| 50 |
+
self.data_dir = self.base_dir / "data"
|
| 51 |
+
self.model_dir = self.base_dir / "model"
|
| 52 |
+
self.logs_dir = self.base_dir / "logs"
|
| 53 |
+
self.backup_dir = self.base_dir / "backups"
|
| 54 |
+
|
| 55 |
+
# Create directories
|
| 56 |
+
for dir_path in [self.data_dir, self.model_dir, self.logs_dir, self.backup_dir]:
|
| 57 |
+
dir_path.mkdir(parents=True, exist_ok=True)
|
| 58 |
+
|
| 59 |
+
# Current production files
|
| 60 |
+
self.prod_model_path = self.model_dir / "model.pkl"
|
| 61 |
+
self.prod_vectorizer_path = self.model_dir / "vectorizer.pkl"
|
| 62 |
+
self.prod_pipeline_path = self.model_dir / "pipeline.pkl"
|
| 63 |
+
|
| 64 |
+
# Candidate files
|
| 65 |
+
self.candidate_model_path = self.model_dir / "model_candidate.pkl"
|
| 66 |
+
self.candidate_vectorizer_path = self.model_dir / "vectorizer_candidate.pkl"
|
| 67 |
+
self.candidate_pipeline_path = self.model_dir / "pipeline_candidate.pkl"
|
| 68 |
+
|
| 69 |
+
# Data files
|
| 70 |
+
self.combined_data_path = self.data_dir / "combined_dataset.csv"
|
| 71 |
+
self.scraped_data_path = self.data_dir / "scraped_real.csv"
|
| 72 |
+
self.generated_data_path = self.data_dir / "generated_fake.csv"
|
| 73 |
+
|
| 74 |
+
# Metadata and logs
|
| 75 |
+
self.metadata_path = Path("/tmp/metadata.json")
|
| 76 |
+
self.retraining_log_path = self.logs_dir / "retraining_log.json"
|
| 77 |
+
self.comparison_log_path = self.logs_dir / "model_comparison.json"
|
| 78 |
+
|
| 79 |
+
def setup_retraining_config(self):
|
| 80 |
+
"""Setup retraining configuration"""
|
| 81 |
+
self.min_new_samples = 50
|
| 82 |
+
self.improvement_threshold = 0.01 # 1% improvement required
|
| 83 |
+
self.significance_level = 0.05
|
| 84 |
+
self.cv_folds = 5
|
| 85 |
+
self.test_size = 0.2
|
| 86 |
+
self.random_state = 42
|
| 87 |
+
self.max_retries = 3
|
| 88 |
+
self.backup_retention_days = 30
|
| 89 |
+
|
| 90 |
+
def setup_statistical_tests(self):
|
| 91 |
+
"""Setup statistical test configurations"""
|
| 92 |
+
self.statistical_tests = {
|
| 93 |
+
'mcnemar': {'alpha': 0.05, 'name': "McNemar's Test"},
|
| 94 |
+
'paired_ttest': {'alpha': 0.05, 'name': "Paired T-Test"},
|
| 95 |
+
'wilcoxon': {'alpha': 0.05, 'name': "Wilcoxon Signed-Rank Test"}
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
def load_existing_metadata(self) -> Optional[Dict]:
|
| 99 |
+
"""Load existing model metadata"""
|
| 100 |
+
try:
|
| 101 |
+
if self.metadata_path.exists():
|
| 102 |
+
with open(self.metadata_path, 'r') as f:
|
| 103 |
+
metadata = json.load(f)
|
| 104 |
+
logger.info(f"Loaded existing metadata: {metadata.get('model_version', 'Unknown')}")
|
| 105 |
+
return metadata
|
| 106 |
+
else:
|
| 107 |
+
logger.warning("No existing metadata found")
|
| 108 |
+
return None
|
| 109 |
+
except Exception as e:
|
| 110 |
+
logger.error(f"Failed to load metadata: {str(e)}")
|
| 111 |
+
return None
|
| 112 |
+
|
| 113 |
+
def load_production_model(self) -> Tuple[bool, Optional[Any], str]:
|
| 114 |
+
"""Load current production model"""
|
| 115 |
+
try:
|
| 116 |
+
# Try to load pipeline first (preferred)
|
| 117 |
+
if self.prod_pipeline_path.exists():
|
| 118 |
+
model = joblib.load(self.prod_pipeline_path)
|
| 119 |
+
logger.info("Loaded production pipeline")
|
| 120 |
+
return True, model, "Pipeline loaded successfully"
|
| 121 |
+
|
| 122 |
+
# Fallback to individual components
|
| 123 |
+
elif self.prod_model_path.exists() and self.prod_vectorizer_path.exists():
|
| 124 |
+
model = joblib.load(self.prod_model_path)
|
| 125 |
+
vectorizer = joblib.load(self.prod_vectorizer_path)
|
| 126 |
+
logger.info("Loaded production model and vectorizer")
|
| 127 |
+
return True, (model, vectorizer), "Model components loaded successfully"
|
| 128 |
+
|
| 129 |
+
else:
|
| 130 |
+
return False, None, "No production model found"
|
| 131 |
+
|
| 132 |
+
except Exception as e:
|
| 133 |
+
error_msg = f"Failed to load production model: {str(e)}"
|
| 134 |
+
logger.error(error_msg)
|
| 135 |
+
return False, None, error_msg
|
| 136 |
+
|
| 137 |
+
def load_new_data(self) -> Tuple[bool, Optional[pd.DataFrame], str]:
|
| 138 |
+
"""Load and combine all available data"""
|
| 139 |
+
try:
|
| 140 |
+
logger.info("Loading training data...")
|
| 141 |
+
|
| 142 |
+
dataframes = []
|
| 143 |
+
|
| 144 |
+
# Load combined dataset (base)
|
| 145 |
+
if self.combined_data_path.exists():
|
| 146 |
+
df_combined = pd.read_csv(self.combined_data_path)
|
| 147 |
+
dataframes.append(df_combined)
|
| 148 |
+
logger.info(f"Loaded combined dataset: {len(df_combined)} samples")
|
| 149 |
+
|
| 150 |
+
# Load scraped real news
|
| 151 |
+
if self.scraped_data_path.exists():
|
| 152 |
+
df_scraped = pd.read_csv(self.scraped_data_path)
|
| 153 |
+
if 'label' not in df_scraped.columns:
|
| 154 |
+
df_scraped['label'] = 0 # Real news
|
| 155 |
+
dataframes.append(df_scraped)
|
| 156 |
+
logger.info(f"Loaded scraped data: {len(df_scraped)} samples")
|
| 157 |
+
|
| 158 |
+
# Load generated fake news
|
| 159 |
+
if self.generated_data_path.exists():
|
| 160 |
+
df_generated = pd.read_csv(self.generated_data_path)
|
| 161 |
+
if 'label' not in df_generated.columns:
|
| 162 |
+
df_generated['label'] = 1 # Fake news
|
| 163 |
+
dataframes.append(df_generated)
|
| 164 |
+
logger.info(f"Loaded generated data: {len(df_generated)} samples")
|
| 165 |
+
|
| 166 |
+
if not dataframes:
|
| 167 |
+
return False, None, "No data files found"
|
| 168 |
+
|
| 169 |
+
# Combine all data
|
| 170 |
+
df = pd.concat(dataframes, ignore_index=True)
|
| 171 |
+
|
| 172 |
+
# Data cleaning and validation
|
| 173 |
+
df = self.clean_and_validate_data(df)
|
| 174 |
+
|
| 175 |
+
if len(df) < 100:
|
| 176 |
+
return False, None, f"Insufficient data after cleaning: {len(df)} samples"
|
| 177 |
+
|
| 178 |
+
logger.info(f"Total training data: {len(df)} samples")
|
| 179 |
+
return True, df, f"Successfully loaded {len(df)} samples"
|
| 180 |
+
|
| 181 |
+
except Exception as e:
|
| 182 |
+
error_msg = f"Failed to load data: {str(e)}"
|
| 183 |
+
logger.error(error_msg)
|
| 184 |
+
return False, None, error_msg
|
| 185 |
+
|
| 186 |
+
def clean_and_validate_data(self, df: pd.DataFrame) -> pd.DataFrame:
|
| 187 |
+
"""Clean and validate the training data"""
|
| 188 |
+
initial_count = len(df)
|
| 189 |
+
|
| 190 |
+
# Remove duplicates
|
| 191 |
+
df = df.drop_duplicates(subset=['text'], keep='first')
|
| 192 |
+
|
| 193 |
+
# Remove null values
|
| 194 |
+
df = df.dropna(subset=['text', 'label'])
|
| 195 |
+
|
| 196 |
+
# Validate text quality
|
| 197 |
+
df = df[df['text'].astype(str).str.len() > 10]
|
| 198 |
+
|
| 199 |
+
# Validate labels
|
| 200 |
+
df = df[df['label'].isin([0, 1])]
|
| 201 |
+
|
| 202 |
+
# Remove excessive length texts
|
| 203 |
+
df = df[df['text'].astype(str).str.len() < 10000]
|
| 204 |
+
|
| 205 |
+
logger.info(f"Data cleaning: {initial_count} -> {len(df)} samples")
|
| 206 |
+
return df
|
| 207 |
+
|
| 208 |
+
def create_advanced_pipeline(self) -> Pipeline:
|
| 209 |
+
"""Create advanced ML pipeline"""
|
| 210 |
+
def preprocess_text(texts):
|
| 211 |
+
import re
|
| 212 |
+
processed = []
|
| 213 |
+
for text in texts:
|
| 214 |
+
text = str(text)
|
| 215 |
+
# Remove URLs and email addresses
|
| 216 |
+
text = re.sub(r'http\S+|www\S+|https\S+|\S+@\S+', '', text)
|
| 217 |
+
# Remove excessive punctuation
|
| 218 |
+
text = re.sub(r'[!]{2,}', '!', text)
|
| 219 |
+
text = re.sub(r'[?]{2,}', '?', text)
|
| 220 |
+
# Remove non-alphabetic characters except spaces and punctuation
|
| 221 |
+
text = re.sub(r'[^a-zA-Z\s.!?]', '', text)
|
| 222 |
+
# Remove excessive whitespace
|
| 223 |
+
text = re.sub(r'\s+', ' ', text)
|
| 224 |
+
processed.append(text.strip().lower())
|
| 225 |
+
return processed
|
| 226 |
+
|
| 227 |
+
# Create pipeline
|
| 228 |
+
pipeline = Pipeline([
|
| 229 |
+
('preprocess', FunctionTransformer(preprocess_text, validate=False)),
|
| 230 |
+
('vectorize', TfidfVectorizer(
|
| 231 |
+
max_features=10000,
|
| 232 |
+
min_df=2,
|
| 233 |
+
max_df=0.95,
|
| 234 |
+
ngram_range=(1, 3),
|
| 235 |
+
stop_words='english',
|
| 236 |
+
sublinear_tf=True
|
| 237 |
+
)),
|
| 238 |
+
('feature_select', SelectKBest(chi2, k=5000)),
|
| 239 |
+
('model', LogisticRegression(
|
| 240 |
+
max_iter=1000,
|
| 241 |
+
class_weight='balanced',
|
| 242 |
+
random_state=self.random_state
|
| 243 |
+
))
|
| 244 |
+
])
|
| 245 |
+
|
| 246 |
+
return pipeline
|
| 247 |
+
|
| 248 |
+
def train_candidate_model(self, df: pd.DataFrame) -> Tuple[bool, Optional[Any], Dict]:
|
| 249 |
+
"""Train candidate model with comprehensive evaluation"""
|
| 250 |
+
try:
|
| 251 |
+
logger.info("Training candidate model...")
|
| 252 |
+
|
| 253 |
+
# Prepare data
|
| 254 |
+
X = df['text'].values
|
| 255 |
+
y = df['label'].values
|
| 256 |
+
|
| 257 |
+
# Train-test split
|
| 258 |
+
from sklearn.model_selection import train_test_split
|
| 259 |
+
X_train, X_test, y_train, y_test = train_test_split(
|
| 260 |
+
X, y, test_size=self.test_size, stratify=y, random_state=self.random_state
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
# Create and train pipeline
|
| 264 |
+
pipeline = self.create_advanced_pipeline()
|
| 265 |
+
pipeline.fit(X_train, y_train)
|
| 266 |
+
|
| 267 |
+
# Evaluate candidate model
|
| 268 |
+
evaluation_results = self.evaluate_model(pipeline, X_test, y_test, X_train, y_train)
|
| 269 |
+
|
| 270 |
+
# Save candidate model
|
| 271 |
+
joblib.dump(pipeline, self.candidate_pipeline_path)
|
| 272 |
+
joblib.dump(pipeline.named_steps['model'], self.candidate_model_path)
|
| 273 |
+
joblib.dump(pipeline.named_steps['vectorize'], self.candidate_vectorizer_path)
|
| 274 |
+
|
| 275 |
+
logger.info(f"Candidate model training completed")
|
| 276 |
+
logger.info(f"Candidate F1 Score: {evaluation_results['f1']:.4f}")
|
| 277 |
+
logger.info(f"Candidate Accuracy: {evaluation_results['accuracy']:.4f}")
|
| 278 |
+
|
| 279 |
+
return True, pipeline, evaluation_results
|
| 280 |
+
|
| 281 |
+
except Exception as e:
|
| 282 |
+
error_msg = f"Candidate model training failed: {str(e)}"
|
| 283 |
+
logger.error(error_msg)
|
| 284 |
+
return False, None, {'error': error_msg}
|
| 285 |
+
|
| 286 |
+
def evaluate_model(self, model, X_test, y_test, X_train=None, y_train=None) -> Dict:
|
| 287 |
+
"""Comprehensive model evaluation"""
|
| 288 |
+
try:
|
| 289 |
+
# Predictions
|
| 290 |
+
y_pred = model.predict(X_test)
|
| 291 |
+
y_pred_proba = model.predict_proba(X_test)[:, 1]
|
| 292 |
+
|
| 293 |
+
# Basic metrics
|
| 294 |
+
metrics = {
|
| 295 |
+
'accuracy': float(accuracy_score(y_test, y_pred)),
|
| 296 |
+
'precision': float(precision_score(y_test, y_pred, average='weighted')),
|
| 297 |
+
'recall': float(recall_score(y_test, y_pred, average='weighted')),
|
| 298 |
+
'f1': float(f1_score(y_test, y_pred, average='weighted')),
|
| 299 |
+
'roc_auc': float(roc_auc_score(y_test, y_pred_proba)),
|
| 300 |
+
'confusion_matrix': confusion_matrix(y_test, y_pred).tolist(),
|
| 301 |
+
'evaluation_timestamp': datetime.now().isoformat()
|
| 302 |
+
}
|
| 303 |
+
|
| 304 |
+
# Cross-validation
|
| 305 |
+
if X_train is not None and y_train is not None:
|
| 306 |
+
try:
|
| 307 |
+
cv_scores = cross_val_score(
|
| 308 |
+
model, X_train, y_train,
|
| 309 |
+
cv=StratifiedKFold(n_splits=self.cv_folds, shuffle=True, random_state=self.random_state),
|
| 310 |
+
scoring='f1_weighted'
|
| 311 |
+
)
|
| 312 |
+
metrics['cv_f1_mean'] = float(cv_scores.mean())
|
| 313 |
+
metrics['cv_f1_std'] = float(cv_scores.std())
|
| 314 |
+
except Exception as e:
|
| 315 |
+
logger.warning(f"Cross-validation failed: {e}")
|
| 316 |
+
|
| 317 |
+
return metrics
|
| 318 |
+
|
| 319 |
+
except Exception as e:
|
| 320 |
+
logger.error(f"Model evaluation failed: {str(e)}")
|
| 321 |
+
return {'error': str(e)}
|
| 322 |
+
|
| 323 |
+
def compare_models_statistically(self, prod_model, candidate_model, X_test, y_test) -> Dict:
|
| 324 |
+
"""Statistical comparison of models"""
|
| 325 |
+
try:
|
| 326 |
+
logger.info("Performing statistical model comparison...")
|
| 327 |
+
|
| 328 |
+
# Get predictions
|
| 329 |
+
prod_pred = prod_model.predict(X_test)
|
| 330 |
+
candidate_pred = candidate_model.predict(X_test)
|
| 331 |
+
|
| 332 |
+
# Calculate accuracies
|
| 333 |
+
prod_accuracy = accuracy_score(y_test, prod_pred)
|
| 334 |
+
candidate_accuracy = accuracy_score(y_test, candidate_pred)
|
| 335 |
+
|
| 336 |
+
comparison_results = {
|
| 337 |
+
'production_accuracy': float(prod_accuracy),
|
| 338 |
+
'candidate_accuracy': float(candidate_accuracy),
|
| 339 |
+
'absolute_improvement': float(candidate_accuracy - prod_accuracy),
|
| 340 |
+
'relative_improvement': float((candidate_accuracy - prod_accuracy) / prod_accuracy * 100),
|
| 341 |
+
'statistical_tests': {}
|
| 342 |
+
}
|
| 343 |
+
|
| 344 |
+
# McNemar's test for paired predictions
|
| 345 |
+
try:
|
| 346 |
+
# Create contingency table
|
| 347 |
+
prod_correct = (prod_pred == y_test)
|
| 348 |
+
candidate_correct = (candidate_pred == y_test)
|
| 349 |
+
|
| 350 |
+
both_correct = np.sum(prod_correct & candidate_correct)
|
| 351 |
+
prod_only = np.sum(prod_correct & ~candidate_correct)
|
| 352 |
+
candidate_only = np.sum(~prod_correct & candidate_correct)
|
| 353 |
+
both_wrong = np.sum(~prod_correct & ~candidate_correct)
|
| 354 |
+
|
| 355 |
+
# McNemar's test
|
| 356 |
+
if prod_only + candidate_only > 0:
|
| 357 |
+
mcnemar_stat = (abs(prod_only - candidate_only) - 1) ** 2 / (prod_only + candidate_only)
|
| 358 |
+
p_value = 1 - stats.chi2.cdf(mcnemar_stat, 1)
|
| 359 |
+
|
| 360 |
+
comparison_results['statistical_tests']['mcnemar'] = {
|
| 361 |
+
'statistic': float(mcnemar_stat),
|
| 362 |
+
'p_value': float(p_value),
|
| 363 |
+
'significant': p_value < self.significance_level,
|
| 364 |
+
'contingency_table': {
|
| 365 |
+
'both_correct': int(both_correct),
|
| 366 |
+
'prod_only': int(prod_only),
|
| 367 |
+
'candidate_only': int(candidate_only),
|
| 368 |
+
'both_wrong': int(both_wrong)
|
| 369 |
+
}
|
| 370 |
+
}
|
| 371 |
+
|
| 372 |
+
except Exception as e:
|
| 373 |
+
logger.warning(f"McNemar's test failed: {e}")
|
| 374 |
+
|
| 375 |
+
# Practical significance test
|
| 376 |
+
comparison_results['practical_significance'] = {
|
| 377 |
+
'meets_threshold': comparison_results['absolute_improvement'] >= self.improvement_threshold,
|
| 378 |
+
'threshold': self.improvement_threshold,
|
| 379 |
+
'recommendation': 'promote' if (
|
| 380 |
+
comparison_results['absolute_improvement'] >= self.improvement_threshold and
|
| 381 |
+
comparison_results['statistical_tests'].get('mcnemar', {}).get('significant', False)
|
| 382 |
+
) else 'keep_current'
|
| 383 |
+
}
|
| 384 |
+
|
| 385 |
+
return comparison_results
|
| 386 |
+
|
| 387 |
+
except Exception as e:
|
| 388 |
+
logger.error(f"Statistical comparison failed: {str(e)}")
|
| 389 |
+
return {'error': str(e)}
|
| 390 |
+
|
| 391 |
+
def create_backup(self) -> bool:
|
| 392 |
+
"""Create backup of current production model"""
|
| 393 |
+
try:
|
| 394 |
+
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
| 395 |
+
backup_dir = self.backup_dir / f"backup_{timestamp}"
|
| 396 |
+
backup_dir.mkdir(parents=True, exist_ok=True)
|
| 397 |
+
|
| 398 |
+
# Backup files
|
| 399 |
+
files_to_backup = [
|
| 400 |
+
(self.prod_model_path, backup_dir / "model.pkl"),
|
| 401 |
+
(self.prod_vectorizer_path, backup_dir / "vectorizer.pkl"),
|
| 402 |
+
(self.prod_pipeline_path, backup_dir / "pipeline.pkl"),
|
| 403 |
+
(self.metadata_path, backup_dir / "metadata.json")
|
| 404 |
+
]
|
| 405 |
+
|
| 406 |
+
for source, dest in files_to_backup:
|
| 407 |
+
if source.exists():
|
| 408 |
+
shutil.copy2(source, dest)
|
| 409 |
+
|
| 410 |
+
logger.info(f"Backup created: {backup_dir}")
|
| 411 |
+
return True
|
| 412 |
+
|
| 413 |
+
except Exception as e:
|
| 414 |
+
logger.error(f"Backup creation failed: {str(e)}")
|
| 415 |
+
return False
|
| 416 |
+
|
| 417 |
+
def promote_candidate_model(self, candidate_model, candidate_metrics: Dict, comparison_results: Dict) -> bool:
|
| 418 |
+
"""Promote candidate model to production"""
|
| 419 |
+
try:
|
| 420 |
+
logger.info("Promoting candidate model to production...")
|
| 421 |
+
|
| 422 |
+
# Create backup first
|
| 423 |
+
if not self.create_backup():
|
| 424 |
+
logger.error("Backup creation failed, aborting promotion")
|
| 425 |
+
return False
|
| 426 |
+
|
| 427 |
+
# Copy candidate files to production
|
| 428 |
+
shutil.copy2(self.candidate_model_path, self.prod_model_path)
|
| 429 |
+
shutil.copy2(self.candidate_vectorizer_path, self.prod_vectorizer_path)
|
| 430 |
+
shutil.copy2(self.candidate_pipeline_path, self.prod_pipeline_path)
|
| 431 |
+
|
| 432 |
+
# Update metadata
|
| 433 |
+
metadata = self.load_existing_metadata() or {}
|
| 434 |
+
|
| 435 |
+
# Increment version
|
| 436 |
+
old_version = metadata.get('model_version', 'v1.0')
|
| 437 |
+
if old_version.startswith('v'):
|
| 438 |
+
try:
|
| 439 |
+
major, minor = map(int, old_version[1:].split('.'))
|
| 440 |
+
new_version = f"v{major}.{minor + 1}"
|
| 441 |
+
except:
|
| 442 |
+
new_version = f"v1.{int(datetime.now().timestamp()) % 1000}"
|
| 443 |
+
else:
|
| 444 |
+
new_version = f"v1.{int(datetime.now().timestamp()) % 1000}"
|
| 445 |
+
|
| 446 |
+
# Update metadata
|
| 447 |
+
metadata.update({
|
| 448 |
+
'model_version': new_version,
|
| 449 |
+
'model_type': 'retrained_pipeline',
|
| 450 |
+
'previous_version': old_version,
|
| 451 |
+
'test_accuracy': candidate_metrics['accuracy'],
|
| 452 |
+
'test_f1': candidate_metrics['f1'],
|
| 453 |
+
'test_precision': candidate_metrics['precision'],
|
| 454 |
+
'test_recall': candidate_metrics['recall'],
|
| 455 |
+
'test_roc_auc': candidate_metrics['roc_auc'],
|
| 456 |
+
'improvement_over_previous': comparison_results['absolute_improvement'],
|
| 457 |
+
'statistical_significance': comparison_results['statistical_tests'].get('mcnemar', {}).get('significant', False),
|
| 458 |
+
'promotion_timestamp': datetime.now().isoformat(),
|
| 459 |
+
'retrain_trigger': 'scheduled_retrain'
|
| 460 |
+
})
|
| 461 |
+
|
| 462 |
+
# Save updated metadata
|
| 463 |
+
with open(self.metadata_path, 'w') as f:
|
| 464 |
+
json.dump(metadata, f, indent=2)
|
| 465 |
+
|
| 466 |
+
logger.info(f"Model promoted successfully to {new_version}")
|
| 467 |
+
return True
|
| 468 |
+
|
| 469 |
+
except Exception as e:
|
| 470 |
+
logger.error(f"Model promotion failed: {str(e)}")
|
| 471 |
+
return False
|
| 472 |
+
|
| 473 |
+
def log_retraining_session(self, results: Dict):
|
| 474 |
+
"""Log retraining session results"""
|
| 475 |
+
try:
|
| 476 |
+
log_entry = {
|
| 477 |
+
'timestamp': datetime.now().isoformat(),
|
| 478 |
+
'results': results,
|
| 479 |
+
'session_id': hashlib.md5(str(datetime.now()).encode()).hexdigest()[:8]
|
| 480 |
+
}
|
| 481 |
+
|
| 482 |
+
# Load existing logs
|
| 483 |
+
logs = []
|
| 484 |
+
if self.retraining_log_path.exists():
|
| 485 |
+
try:
|
| 486 |
+
with open(self.retraining_log_path, 'r') as f:
|
| 487 |
+
logs = json.load(f)
|
| 488 |
+
except:
|
| 489 |
+
logs = []
|
| 490 |
+
|
| 491 |
+
# Add new log
|
| 492 |
+
logs.append(log_entry)
|
| 493 |
+
|
| 494 |
+
# Keep only last 100 entries
|
| 495 |
+
if len(logs) > 100:
|
| 496 |
+
logs = logs[-100:]
|
| 497 |
+
|
| 498 |
+
# Save logs
|
| 499 |
+
with open(self.retraining_log_path, 'w') as f:
|
| 500 |
+
json.dump(logs, f, indent=2)
|
| 501 |
+
|
| 502 |
+
except Exception as e:
|
| 503 |
+
logger.error(f"Failed to log retraining session: {str(e)}")
|
| 504 |
+
|
| 505 |
+
def retrain_model(self) -> Tuple[bool, str]:
|
| 506 |
+
"""Main retraining function with comprehensive validation"""
|
| 507 |
+
try:
|
| 508 |
+
logger.info("Starting model retraining process...")
|
| 509 |
+
|
| 510 |
+
# Load existing metadata
|
| 511 |
+
existing_metadata = self.load_existing_metadata()
|
| 512 |
+
|
| 513 |
+
# Load production model
|
| 514 |
+
prod_success, prod_model, prod_msg = self.load_production_model()
|
| 515 |
+
if not prod_success:
|
| 516 |
+
logger.warning(f"No production model found: {prod_msg}")
|
| 517 |
+
# Fall back to initial training
|
| 518 |
+
from model.train import main as train_main
|
| 519 |
+
train_main()
|
| 520 |
+
return True, "Initial training completed"
|
| 521 |
+
|
| 522 |
+
# Load new data
|
| 523 |
+
data_success, df, data_msg = self.load_new_data()
|
| 524 |
+
if not data_success:
|
| 525 |
+
return False, data_msg
|
| 526 |
+
|
| 527 |
+
# Check if we have enough new data
|
| 528 |
+
if len(df) < self.min_new_samples:
|
| 529 |
+
return False, f"Insufficient new data: {len(df)} < {self.min_new_samples}"
|
| 530 |
+
|
| 531 |
+
# Train candidate model
|
| 532 |
+
candidate_success, candidate_model, candidate_metrics = self.train_candidate_model(df)
|
| 533 |
+
if not candidate_success:
|
| 534 |
+
return False, f"Candidate training failed: {candidate_metrics.get('error', 'Unknown error')}"
|
| 535 |
+
|
| 536 |
+
# Prepare test data for comparison
|
| 537 |
+
X = df['text'].values
|
| 538 |
+
y = df['label'].values
|
| 539 |
+
from sklearn.model_selection import train_test_split
|
| 540 |
+
_, X_test, _, y_test = train_test_split(
|
| 541 |
+
X, y, test_size=self.test_size, stratify=y, random_state=self.random_state
|
| 542 |
+
)
|
| 543 |
+
|
| 544 |
+
# Compare models
|
| 545 |
+
comparison_results = self.compare_models_statistically(
|
| 546 |
+
prod_model, candidate_model, X_test, y_test
|
| 547 |
+
)
|
| 548 |
+
|
| 549 |
+
# Log results
|
| 550 |
+
session_results = {
|
| 551 |
+
'candidate_metrics': candidate_metrics,
|
| 552 |
+
'comparison_results': comparison_results,
|
| 553 |
+
'data_size': len(df),
|
| 554 |
+
'test_size': len(X_test)
|
| 555 |
+
}
|
| 556 |
+
|
| 557 |
+
self.log_retraining_session(session_results)
|
| 558 |
+
|
| 559 |
+
# Decide whether to promote
|
| 560 |
+
should_promote = (
|
| 561 |
+
comparison_results['absolute_improvement'] >= self.improvement_threshold and
|
| 562 |
+
comparison_results.get('statistical_tests', {}).get('mcnemar', {}).get('significant', False)
|
| 563 |
+
)
|
| 564 |
+
|
| 565 |
+
if should_promote:
|
| 566 |
+
# Promote candidate model
|
| 567 |
+
promotion_success = self.promote_candidate_model(
|
| 568 |
+
candidate_model, candidate_metrics, comparison_results
|
| 569 |
+
)
|
| 570 |
+
|
| 571 |
+
if promotion_success:
|
| 572 |
+
success_msg = (
|
| 573 |
+
f"Model promoted successfully! "
|
| 574 |
+
f"Improvement: {comparison_results['absolute_improvement']:.4f} "
|
| 575 |
+
f"(F1: {candidate_metrics['f1']:.4f})"
|
| 576 |
+
)
|
| 577 |
+
logger.info(success_msg)
|
| 578 |
+
return True, success_msg
|
| 579 |
+
else:
|
| 580 |
+
return False, "Model promotion failed"
|
| 581 |
+
else:
|
| 582 |
+
# Keep current model
|
| 583 |
+
keep_msg = (
|
| 584 |
+
f"Keeping current model. "
|
| 585 |
+
f"Improvement: {comparison_results['absolute_improvement']:.4f} "
|
| 586 |
+
f"(threshold: {self.improvement_threshold})"
|
| 587 |
+
)
|
| 588 |
+
logger.info(keep_msg)
|
| 589 |
+
return True, keep_msg
|
| 590 |
+
|
| 591 |
+
except Exception as e:
|
| 592 |
+
error_msg = f"Model retraining failed: {str(e)}"
|
| 593 |
+
logger.error(error_msg)
|
| 594 |
+
return False, error_msg
|
| 595 |
|
| 596 |
def main():
|
| 597 |
+
"""Main execution function"""
|
| 598 |
+
retrainer = RobustModelRetrainer()
|
| 599 |
+
success, message = retrainer.retrain_model()
|
| 600 |
+
|
| 601 |
+
if success:
|
| 602 |
+
print(f"✅ {message}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 603 |
else:
|
| 604 |
+
print(f"❌ {message}")
|
| 605 |
+
exit(1)
|
| 606 |
|
| 607 |
if __name__ == "__main__":
|
| 608 |
main()
|