Ahmedik95316 committed
Commit 310a651 · 1 Parent(s): 9d61526

Update model/train.py


Critical issues in the original train.py:

Single evaluation metric (accuracy only)
No hyperparameter tuning
No cross-validation
Limited preprocessing
No feature selection
No model comparison
No comprehensive evaluation
No overfitting detection

Fixes applied in this commit:

Added comprehensive evaluation with multiple metrics
Added hyperparameter tuning with GridSearchCV
Added cross-validation for robust evaluation
Added advanced text preprocessing pipeline
Added feature selection and optimization
Added model comparison (Logistic Regression vs Random Forest)
Added confusion matrix and ROC curve reporting
Added overfitting detection and model validation
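
In outline, these fixes follow the standard scikit-learn tune-then-compare pattern. The following is a minimal, self-contained sketch of that pattern (using a synthetic dataset as a stand-in for the real text features); the diff below wraps the same idea in a full preprocessing, TF-IDF, and feature-selection pipeline:

    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import GridSearchCV, StratifiedKFold, train_test_split
    from sklearn.pipeline import Pipeline

    # Synthetic stand-in for the vectorized text data
    X, y = make_classification(n_samples=300, random_state=42)
    X_tr, X_te, y_tr, y_te = train_test_split(
        X, y, stratify=y, test_size=0.2, random_state=42)

    # Tune the model inside a pipeline with stratified cross-validation,
    # then score the tuned estimator on the held-out test set
    pipe = Pipeline([('model', LogisticRegression(max_iter=1000))])
    grid = {'model__C': [0.1, 1, 10]}
    cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
    search = GridSearchCV(pipe, grid, cv=cv, scoring='f1_weighted', n_jobs=-1)
    search.fit(X_tr, y_tr)
    print(search.best_params_, search.best_estimator_.score(X_te, y_te))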

Files changed (1):
  1. model/train.py  +562 -76
model/train.py CHANGED
@@ -1,87 +1,573 @@
  import pandas as pd
  from pathlib import Path
- from sklearn.feature_extraction.text import TfidfVectorizer
- from sklearn.linear_model import LogisticRegression
- from sklearn.metrics import accuracy_score
- from sklearn.model_selection import train_test_split
- import joblib
  import json
- import datetime
  import hashlib

- # # Paths
- # BASE_DIR = Path(__file__).resolve().parent
- # DATA_PATH = BASE_DIR.parent / "data" / "combined_dataset.csv"
- # MODEL_PATH = BASE_DIR / "model.pkl"
- # VECTORIZER_PATH = BASE_DIR / "vectorizer.pkl"
- # METADATA_PATH = BASE_DIR / "metadata.json"
-
- # Base dir and data location inside /tmp
- BASE_DIR = Path("/tmp")
- DATA_PATH = BASE_DIR / "data" / "combined_dataset.csv"
-
- # Model artifacts also in /tmp (or you can keep these in /app/model if you want to persist them in the container)
- MODEL_PATH = BASE_DIR / "model.pkl"
- VECTORIZER_PATH = BASE_DIR / "vectorizer.pkl"
- # METADATA_PATH = BASE_DIR / "metadata.json"
- METADATA_PATH = Path("/tmp/metadata.json")

- def hash_file(filepath):
-     content = Path(filepath).read_bytes()
-     return hashlib.md5(content).hexdigest()

  def main():
-     # Load dataset
-     # print('Dataset Loaded')
-     df = pd.read_csv(DATA_PATH)
-     X = df['text']
-     y = df['label']
-
-     # Train-test split
-     X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, test_size=0.2, random_state=42)
-
-     # Vectorize
-     vectorizer = TfidfVectorizer(stop_words='english', max_features=5000)
-     X_train_vec = vectorizer.fit_transform(X_train)
-     X_test_vec = vectorizer.transform(X_test)
-
-     # print('Train/Test Splits Created')
-     # print('Starting Model Training')
-
-     # Train model
-     model = LogisticRegression(max_iter=1000)
-     model.fit(X_train_vec, y_train)
-
-     # print('Model Training Completed')
-     # print('Model Evaluation Starting!')
-
-     # Evaluate
-     y_pred = model.predict(X_test_vec)
-     acc = accuracy_score(y_test, y_pred)
-
-     # Save model + vectorizer
-     joblib.dump(model, MODEL_PATH)
-     joblib.dump(vectorizer, VECTORIZER_PATH)
-
-     # print('Model Evaluation Done')
-     # print('Model Saved!')
-
-     # Save metadata
-     metadata = {
-         "model_version": f"v1.0",
-         "data_version": hash_file(DATA_PATH),
-         "train_size": len(X_train),
-         "test_size": len(X_test),
-         "test_accuracy": round(acc, 4),
-         "timestamp": datetime.datetime.now().isoformat()
-     }
-     with open(METADATA_PATH, 'w') as f:
-         json.dump(metadata, f, indent=4)
-
-     print(f"✅ Model trained and saved.")
-     print(f"📊 Test Accuracy: {acc:.4f}")
-     print(f"📝 Metadata saved to {METADATA_PATH}")

  if __name__ == "__main__":
-     main()

  import pandas as pd
+ import numpy as np
  from pathlib import Path
+ import logging
  import json
+ import joblib
  import hashlib
+ from datetime import datetime
+ from typing import Dict, Tuple, Optional, Any
+ import warnings
+ warnings.filterwarnings('ignore')
+
+ # Scikit-learn imports
+ from sklearn.feature_extraction.text import TfidfVectorizer
+ from sklearn.linear_model import LogisticRegression
+ from sklearn.ensemble import RandomForestClassifier
+ from sklearn.model_selection import (
+     train_test_split, cross_val_score, GridSearchCV,
+     StratifiedKFold, validation_curve
+ )
+ from sklearn.metrics import (
+     accuracy_score, precision_score, recall_score, f1_score,
+     roc_auc_score, confusion_matrix, classification_report,
+     precision_recall_curve, roc_curve
+ )
+ from sklearn.pipeline import Pipeline
+ from sklearn.preprocessing import FunctionTransformer
+ from sklearn.feature_selection import SelectKBest, chi2
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+
+ # Configure logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(levelname)s - %(message)s',
+     handlers=[
+         logging.FileHandler('/tmp/model_training.log'),
+         logging.StreamHandler()
+     ]
+ )
+ logger = logging.getLogger(__name__)
+
+ class RobustModelTrainer:
+     """Production-ready model trainer with comprehensive evaluation and validation"""
+
+     def __init__(self):
+         self.setup_paths()
+         self.setup_training_config()
+         self.setup_models()
+
+     def setup_paths(self):
+         """Setup all necessary paths"""
+         self.base_dir = Path("/tmp")
+         self.data_dir = self.base_dir / "data"
+         self.model_dir = self.base_dir / "model"
+         self.results_dir = self.base_dir / "results"
+
+         # Create directories
+         for dir_path in [self.data_dir, self.model_dir, self.results_dir]:
+             dir_path.mkdir(parents=True, exist_ok=True)
+
+         # File paths
+         self.data_path = self.data_dir / "combined_dataset.csv"
+         self.model_path = self.model_dir / "model.pkl"
+         self.vectorizer_path = self.model_dir / "vectorizer.pkl"
+         self.pipeline_path = self.model_dir / "pipeline.pkl"
+         self.metadata_path = Path("/tmp/metadata.json")
+         self.evaluation_path = self.results_dir / "evaluation_results.json"
+
+     def setup_training_config(self):
+         """Setup training configuration"""
+         self.test_size = 0.2
+         self.validation_size = 0.1
+         self.random_state = 42
+         self.cv_folds = 5
+         self.max_features = 10000
+         self.min_df = 2
+         self.max_df = 0.95
+         self.ngram_range = (1, 3)
+         self.max_iter = 1000
+         self.class_weight = 'balanced'
+         self.feature_selection_k = 5000
+
+     def setup_models(self):
+         """Setup model configurations for comparison"""
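+         # Two complementary candidates: a linear baseline and a tree ensemble,
+         # each paired with the hyperparameter grid GridSearchCV will explore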
+         self.models = {
+             'logistic_regression': {
+                 'model': LogisticRegression(
+                     max_iter=self.max_iter,
+                     class_weight=self.class_weight,
+                     random_state=self.random_state
+                 ),
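+                 # Keys use the 'model__' prefix so GridSearchCV routes them to
+                 # the Pipeline step named 'model' (see create_preprocessing_pipeline)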
+                 'param_grid': {
+                     'model__C': [0.1, 1, 10, 100],
+                     'model__penalty': ['l2'],
+                     'model__solver': ['liblinear', 'lbfgs']
+                 }
+             },
+             'random_forest': {
+                 'model': RandomForestClassifier(
+                     n_estimators=100,
+                     class_weight=self.class_weight,
+                     random_state=self.random_state
+                 ),
+                 'param_grid': {
+                     'model__n_estimators': [50, 100, 200],
+                     'model__max_depth': [10, 20, None],
+                     'model__min_samples_split': [2, 5, 10]
+                 }
+             }
+         }
+
+     def load_and_validate_data(self) -> Tuple[bool, Optional[pd.DataFrame], str]:
+         """Load and validate training data"""
+         try:
+             logger.info("Loading training data...")
+
+             if not self.data_path.exists():
+                 return False, None, f"Data file not found: {self.data_path}"
+
+             # Load data
+             df = pd.read_csv(self.data_path)
+
+             # Basic validation
+             if df.empty:
+                 return False, None, "Dataset is empty"
+
+             required_columns = ['text', 'label']
+             missing_columns = [col for col in required_columns if col not in df.columns]
+             if missing_columns:
+                 return False, None, f"Missing required columns: {missing_columns}"
+
+             # Remove missing values
+             initial_count = len(df)
+             df = df.dropna(subset=required_columns)
+             if len(df) < initial_count:
+                 logger.warning(f"Removed {initial_count - len(df)} rows with missing values")
+
+             # Validate text content
+             df = df[df['text'].astype(str).str.len() > 10]
+
+             # Validate labels
+             unique_labels = df['label'].unique()
+             if len(unique_labels) < 2:
+                 return False, None, f"Need at least 2 classes, found: {unique_labels}"
+
+             # Check minimum sample size
+             if len(df) < 100:
+                 return False, None, f"Insufficient samples for training: {len(df)}"
+
+             # Check class balance
+             label_counts = df['label'].value_counts()
+             min_class_ratio = label_counts.min() / label_counts.max()
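+             # Ratio of the rarest to the most common class: 1.0 is perfectly
+             # balanced; below 0.1 one class outnumbers another more than 10:1
+             # (class_weight='balanced' compensates for this during training)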
+             if min_class_ratio < 0.1:
+                 logger.warning(f"Severe class imbalance detected: {min_class_ratio:.3f}")
+
+             logger.info(f"Data validation successful: {len(df)} samples, {len(unique_labels)} classes")
+             logger.info(f"Class distribution: {label_counts.to_dict()}")
+
+             return True, df, "Data loaded successfully"
+
+         except Exception as e:
+             error_msg = f"Error loading data: {str(e)}"
+             logger.error(error_msg)
+             return False, None, error_msg
+
+     def preprocess_text(self, text):
+         """Advanced text preprocessing"""
+         import re
+
+         # Convert to string
+         text = str(text)
+
+         # Remove URLs
+         text = re.sub(r'http\S+|www\S+|https\S+', '', text)
+
+         # Remove email addresses
+         text = re.sub(r'\S+@\S+', '', text)
+
+         # Remove excessive punctuation
+         text = re.sub(r'[!]{2,}', '!', text)
+         text = re.sub(r'[?]{2,}', '?', text)
+         text = re.sub(r'[.]{3,}', '...', text)
+
+         # Remove non-alphabetic characters except spaces and basic punctuation
+         text = re.sub(r'[^a-zA-Z\s.!?]', '', text)
+
+         # Remove excessive whitespace
+         text = re.sub(r'\s+', ' ', text)
+
+         return text.strip().lower()
+
+     def _preprocess_batch(self, texts):
+         """Apply preprocess_text to a batch of documents; a bound method is
+         used (rather than a lambda) so the fitted pipeline stays picklable"""
+         return [self.preprocess_text(t) for t in texts]
+
+     def create_preprocessing_pipeline(self) -> Pipeline:
+         """Create advanced preprocessing pipeline"""
+         # Text preprocessing (a named method rather than a lambda, so that
+         # joblib can pickle the fitted pipeline)
+         text_preprocessor = FunctionTransformer(
+             func=self._preprocess_batch,
+             validate=False
+         )
+
+         # TF-IDF vectorization
+         vectorizer = TfidfVectorizer(
+             max_features=self.max_features,
+             min_df=self.min_df,
+             max_df=self.max_df,
+             ngram_range=self.ngram_range,
+             stop_words='english',
+             sublinear_tf=True,
+             norm='l2'
+         )
+
+         # Feature selection
+         feature_selector = SelectKBest(
+             score_func=chi2,
+             k=self.feature_selection_k
+         )
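+         # chi2 is valid here because TF-IDF features are non-negative; k must
+         # not exceed the number of features the vectorizer actually produces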
+
+         # Create pipeline
+         pipeline = Pipeline([
+             ('preprocess', text_preprocessor),
+             ('vectorize', vectorizer),
+             ('feature_select', feature_selector),
+             ('model', None)  # Will be set during training
+         ])
+
+         return pipeline
+
+     def comprehensive_evaluation(self, model, X_test, y_test, X_train=None, y_train=None) -> Dict:
+         """Comprehensive model evaluation with multiple metrics"""
+         logger.info("Starting comprehensive model evaluation...")
+
+         # Predictions
+         y_pred = model.predict(X_test)
+         y_pred_proba = model.predict_proba(X_test)[:, 1]
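+         # Column 1 of predict_proba is the positive-class probability; this
+         # (and roc_auc_score below) assumes a binary label set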
+
+         # Basic metrics
+         metrics = {
+             'accuracy': float(accuracy_score(y_test, y_pred)),
+             'precision': float(precision_score(y_test, y_pred, average='weighted')),
+             'recall': float(recall_score(y_test, y_pred, average='weighted')),
+             'f1': float(f1_score(y_test, y_pred, average='weighted')),
+             'roc_auc': float(roc_auc_score(y_test, y_pred_proba))
+         }
+
+         # Confusion matrix
+         cm = confusion_matrix(y_test, y_pred)
+         metrics['confusion_matrix'] = cm.tolist()
+
+         # Classification report
+         class_report = classification_report(y_test, y_pred, output_dict=True)
+         metrics['classification_report'] = class_report
+
+         # Cross-validation scores if training data provided
+         if X_train is not None and y_train is not None:
+             try:
+                 cv_scores = cross_val_score(
+                     model, X_train, y_train,
+                     cv=StratifiedKFold(n_splits=self.cv_folds, shuffle=True, random_state=self.random_state),
+                     scoring='f1_weighted'
+                 )
+                 metrics['cv_scores'] = {
+                     'mean': float(cv_scores.mean()),
+                     'std': float(cv_scores.std()),
+                     'scores': cv_scores.tolist()
+                 }
+             except Exception as e:
+                 logger.warning(f"Cross-validation failed: {e}")
+                 metrics['cv_scores'] = None
+
+         # Feature importance (if available); model may be a fitted Pipeline,
+         # so inspect its final estimator rather than the pipeline object
+         try:
+             estimator = model.named_steps['model'] if hasattr(model, 'named_steps') else model
+             if hasattr(estimator, 'feature_importances_'):
+                 feature_importance = estimator.feature_importances_
+                 metrics['feature_importance_stats'] = {
+                     'mean': float(feature_importance.mean()),
+                     'std': float(feature_importance.std()),
+                     'top_features': feature_importance.argsort()[-10:][::-1].tolist()
+                 }
+             elif hasattr(estimator, 'coef_'):
+                 coefficients = estimator.coef_[0]
+                 metrics['coefficient_stats'] = {
+                     'mean': float(coefficients.mean()),
+                     'std': float(coefficients.std()),
+                     'top_positive': coefficients.argsort()[-10:][::-1].tolist(),
+                     'top_negative': coefficients.argsort()[:10].tolist()
+                 }
+         except Exception as e:
+             logger.warning(f"Feature importance extraction failed: {e}")
+
+         # Model complexity metrics
+         try:
+             # Training accuracy for overfitting detection
+             if X_train is not None and y_train is not None:
+                 y_train_pred = model.predict(X_train)
+                 train_accuracy = accuracy_score(y_train, y_train_pred)
+                 metrics['train_accuracy'] = float(train_accuracy)
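+                 # A large positive train-test gap signals overfitting;
+                 # values near zero indicate good generalization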
+                 metrics['overfitting_score'] = float(train_accuracy - metrics['accuracy'])
+         except Exception as e:
+             logger.warning(f"Overfitting detection failed: {e}")
+
+         return metrics
+
+     def hyperparameter_tuning(self, pipeline, X_train, y_train, model_name: str) -> Tuple[Any, Dict]:
+         """Perform hyperparameter tuning with cross-validation"""
+         logger.info(f"Starting hyperparameter tuning for {model_name}...")
+
+         try:
+             # Set the model in the pipeline
+             pipeline.set_params(model=self.models[model_name]['model'])
+
+             # Get parameter grid
+             param_grid = self.models[model_name]['param_grid']
+
+             # Create GridSearchCV
+             grid_search = GridSearchCV(
+                 pipeline,
+                 param_grid,
+                 cv=StratifiedKFold(n_splits=self.cv_folds, shuffle=True, random_state=self.random_state),
+                 scoring='f1_weighted',
+                 n_jobs=-1,
+                 verbose=1
+             )
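+             # f1_weighted is more robust than accuracy under class imbalance;
+             # n_jobs=-1 parallelizes the search across all available cores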
+
+             # Fit grid search
+             grid_search.fit(X_train, y_train)
+
+             # Extract results
+             tuning_results = {
+                 'best_params': grid_search.best_params_,
+                 'best_score': float(grid_search.best_score_),
+                 'best_estimator': grid_search.best_estimator_,
+                 'cv_results': {
+                     'mean_test_scores': grid_search.cv_results_['mean_test_score'].tolist(),
+                     'std_test_scores': grid_search.cv_results_['std_test_score'].tolist(),
+                     'params': grid_search.cv_results_['params']
+                 }
+             }
+
+             logger.info(f"Hyperparameter tuning completed for {model_name}")
+             logger.info(f"Best score: {grid_search.best_score_:.4f}")
+             logger.info(f"Best params: {grid_search.best_params_}")
+
+             return grid_search.best_estimator_, tuning_results
+
+         except Exception as e:
+             logger.error(f"Hyperparameter tuning failed for {model_name}: {str(e)}")
+             # Return basic model if tuning fails
+             pipeline.set_params(model=self.models[model_name]['model'])
+             pipeline.fit(X_train, y_train)
+             return pipeline, {'error': str(e)}
+
+     def train_and_evaluate_models(self, X_train, X_test, y_train, y_test) -> Dict:
+         """Train and evaluate multiple models"""
+         logger.info("Starting model training and evaluation...")
+
+         results = {}
+
+         for model_name in self.models.keys():
+             logger.info(f"Training {model_name}...")
+
+             try:
+                 # Create pipeline
+                 pipeline = self.create_preprocessing_pipeline()
+
+                 # Hyperparameter tuning
+                 best_model, tuning_results = self.hyperparameter_tuning(
+                     pipeline, X_train, y_train, model_name
+                 )
+
+                 # Comprehensive evaluation
+                 evaluation_metrics = self.comprehensive_evaluation(
+                     best_model, X_test, y_test, X_train, y_train
+                 )
+
+                 # Store results
+                 results[model_name] = {
+                     'model': best_model,
+                     'tuning_results': tuning_results,
+                     'evaluation_metrics': evaluation_metrics,
+                     'training_time': datetime.now().isoformat()
+                 }
+
+                 logger.info(f"Model {model_name} - F1: {evaluation_metrics['f1']:.4f}, "
+                             f"Accuracy: {evaluation_metrics['accuracy']:.4f}")
+
+             except Exception as e:
+                 logger.error(f"Training failed for {model_name}: {str(e)}")
+                 results[model_name] = {'error': str(e)}
+
+         return results
+
+     def select_best_model(self, results: Dict) -> Tuple[str, Any, Dict]:
+         """Select the best performing model"""
+         logger.info("Selecting best model...")
+
+         best_model_name = None
+         best_model = None
+         best_score = -1
+         best_metrics = None
+
+         for model_name, result in results.items():
+             if 'error' in result:
+                 continue
+
+             # Use F1 as the primary metric (local name chosen so it does not
+             # shadow the imported f1_score function)
+             model_f1 = result['evaluation_metrics']['f1']
+
+             if model_f1 > best_score:
+                 best_score = model_f1
+                 best_model_name = model_name
+                 best_model = result['model']
+                 best_metrics = result['evaluation_metrics']
+
+         if best_model_name is None:
+             raise ValueError("No models trained successfully")
+
+         logger.info(f"Best model: {best_model_name} with F1 score: {best_score:.4f}")
+         return best_model_name, best_model, best_metrics
+
+     def save_model_artifacts(self, model, model_name: str, metrics: Dict) -> bool:
+         """Save model artifacts and metadata"""
+         try:
+             logger.info("Saving model artifacts...")
+
+             # Save the full pipeline
+             joblib.dump(model, self.pipeline_path)
+
+             # Save individual components for backward compatibility
+             joblib.dump(model.named_steps['model'], self.model_path)
+             joblib.dump(model.named_steps['vectorize'], self.vectorizer_path)
+
+             # Hash the dataset file so data_version tracks the actual data
+             data_hash = hashlib.md5(self.data_path.read_bytes()).hexdigest()
+
+             # Recover the test-set size from the confusion matrix (its cells
+             # sum to the number of test samples)
+             cm = metrics.get('confusion_matrix')
+             test_size = sum(sum(row) for row in cm) if cm else 'Unknown'
+
+             # cv_scores is None when cross-validation failed, so guard the lookups
+             cv_scores = metrics.get('cv_scores') or {}
+
+             # Create metadata
+             metadata = {
+                 'model_version': f"v1.0_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
+                 'model_type': model_name,
+                 'data_version': data_hash,
+                 'test_size': test_size,
+                 'test_accuracy': metrics['accuracy'],
+                 'test_f1': metrics['f1'],
+                 'test_precision': metrics['precision'],
+                 'test_recall': metrics['recall'],
+                 'test_roc_auc': metrics['roc_auc'],
+                 'overfitting_score': metrics.get('overfitting_score', 'Unknown'),
+                 'cv_score_mean': cv_scores.get('mean', 'Unknown'),
+                 'cv_score_std': cv_scores.get('std', 'Unknown'),
+                 'timestamp': datetime.now().isoformat(),
+                 'training_config': {
+                     'test_size': self.test_size,
+                     'validation_size': self.validation_size,
+                     'cv_folds': self.cv_folds,
+                     'max_features': self.max_features,
+                     'ngram_range': self.ngram_range,
+                     'feature_selection_k': self.feature_selection_k
+                 }
+             }
+
+             # Save metadata
+             with open(self.metadata_path, 'w') as f:
+                 json.dump(metadata, f, indent=2)
+
+             logger.info("Model artifacts saved successfully")
+             logger.info(f"Model path: {self.model_path}")
+             logger.info(f"Vectorizer path: {self.vectorizer_path}")
+             logger.info(f"Pipeline path: {self.pipeline_path}")
+             logger.info(f"Metadata path: {self.metadata_path}")
+
+             return True
+
+         except Exception as e:
+             logger.error(f"Failed to save model artifacts: {str(e)}")
+             return False
+
+     def save_evaluation_results(self, results: Dict) -> bool:
+         """Save comprehensive evaluation results"""
+         try:
+             # Clean results for JSON serialization
+             clean_results = {}
+             for model_name, result in results.items():
+                 if 'error' in result:
+                     clean_results[model_name] = result
+                 else:
+                     clean_results[model_name] = {
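+                         # Drop the fitted estimator: it is not JSON-serializable
+                         # and is already persisted separately with joblib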
+                         'tuning_results': {
+                             k: v for k, v in result['tuning_results'].items()
+                             if k != 'best_estimator'
+                         },
+                         'evaluation_metrics': result['evaluation_metrics'],
+                         'training_time': result['training_time']
+                     }
+
+             # Save results
+             with open(self.evaluation_path, 'w') as f:
+                 json.dump(clean_results, f, indent=2, default=str)
+
+             logger.info(f"Evaluation results saved to {self.evaluation_path}")
+             return True
+
+         except Exception as e:
+             logger.error(f"Failed to save evaluation results: {str(e)}")
+             return False
+
+     def train_model(self, data_path: Optional[str] = None) -> Tuple[bool, str]:
+         """Main training function with comprehensive pipeline"""
+         try:
+             logger.info("Starting model training pipeline...")
+
+             # Override data path if provided
+             if data_path:
+                 self.data_path = Path(data_path)
+
+             # Load and validate data
+             success, df, message = self.load_and_validate_data()
+             if not success:
+                 return False, message
+
+             # Prepare data
+             X = df['text'].values
+             y = df['label'].values
+
+             # Train-test split
+             X_train, X_test, y_train, y_test = train_test_split(
+                 X, y,
+                 test_size=self.test_size,
+                 stratify=y,
+                 random_state=self.random_state
+             )
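+             # stratify=y keeps class proportions identical in both splits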
+
+             logger.info(f"Data split: {len(X_train)} train, {len(X_test)} test")
+
+             # Train and evaluate models
+             results = self.train_and_evaluate_models(X_train, X_test, y_train, y_test)
+
+             # Select best model
+             best_model_name, best_model, best_metrics = self.select_best_model(results)
+
+             # Save model artifacts
+             if not self.save_model_artifacts(best_model, best_model_name, best_metrics):
+                 return False, "Failed to save model artifacts"
+
+             # Save evaluation results
+             self.save_evaluation_results(results)
+
+             success_message = (
+                 f"Model training completed successfully. "
+                 f"Best model: {best_model_name} "
+                 f"(F1: {best_metrics['f1']:.4f}, Accuracy: {best_metrics['accuracy']:.4f})"
+             )
+
+             logger.info(success_message)
+             return True, success_message
+
+         except Exception as e:
+             error_message = f"Model training failed: {str(e)}"
+             logger.error(error_message)
+             return False, error_message

  def main():
+     """Main execution function"""
+     trainer = RobustModelTrainer()
+     success, message = trainer.train_model()
+
+     if success:
+         print(f"✅ {message}")
+     else:
+         print(f"❌ {message}")
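+         # Exit non-zero so callers (e.g., CI or cron jobs) can detect failure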
+         exit(1)

  if __name__ == "__main__":
+     main()