Ahmedik95316 committed
Commit c29bcf3 · 1 Parent(s): 63bd8f9

Update model/train.py


Fixed an error where the model was not saved after training

- When GridSearchCV pickles the pipeline during hyperparameter tuning, serialization failed because the text preprocessor was a lambda, and lambdas are not picklable. The preprocessing logic now lives in a module-level function instead (a minimal repro follows below).
- Removed the premature save in the pipeline-construction code, which dumped the pipeline while its 'model' step was still None; the save path now checks that each component exists before dumping it.
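The pickling failure is easy to reproduce outside the trainer. A minimal sketch, assuming nothing about the real pipeline beyond its use of `FunctionTransformer` (the `preprocess` helper here is illustrative, not the project's function):

```python
import pickle

from sklearn.preprocessing import FunctionTransformer


def preprocess(texts):
    # Module-level function: pickle stores it by qualified name, so it round-trips
    return [str(t).strip().lower() for t in texts]


# A lambda has no importable qualified name, so pickle rejects it
try:
    pickle.dumps(FunctionTransformer(func=lambda x: [str(t).lower() for t in x]))
except (pickle.PicklingError, AttributeError) as exc:
    print(f"lambda-based transformer fails to pickle: {exc}")

# The named function serializes and deserializes cleanly
restored = pickle.loads(pickle.dumps(FunctionTransformer(func=preprocess)))
print(restored.transform(["  Hello WORLD  "]))  # ['hello world']
```

The same constraint applies to `joblib.dump` and to the copies GridSearchCV ships to worker processes when `n_jobs != 1`, which is why the fix moves the preprocessing logic to the module-level `preprocess_text_function`.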

Files changed (1)
  1. model/train.py +54 -42
model/train.py
```diff
@@ -25,10 +25,9 @@ import hashlib
 from datetime import datetime
 from typing import Dict, Tuple, Optional, Any
 import warnings
+import re
 warnings.filterwarnings('ignore')
 
-# Scikit-learn imports
-
 # Configure logging
 logging.basicConfig(
     level=logging.INFO,
@@ -41,6 +40,41 @@ logging.basicConfig(
 logger = logging.getLogger(__name__)
 
 
+def preprocess_text_function(texts):
+    """
+    Standalone function for text preprocessing - pickle-safe
+    """
+    def clean_single_text(text):
+        # Convert to string
+        text = str(text)
+
+        # Remove URLs
+        text = re.sub(r'http\S+|www\S+|https\S+', '', text)
+
+        # Remove email addresses
+        text = re.sub(r'\S+@\S+', '', text)
+
+        # Remove excessive punctuation
+        text = re.sub(r'[!]{2,}', '!', text)
+        text = re.sub(r'[?]{2,}', '?', text)
+        text = re.sub(r'[.]{3,}', '...', text)
+
+        # Remove non-alphabetic characters except spaces and basic punctuation
+        text = re.sub(r'[^a-zA-Z\s.!?]', '', text)
+
+        # Remove excessive whitespace
+        text = re.sub(r'\s+', ' ', text)
+
+        return text.strip().lower()
+
+    # Process all texts
+    processed = []
+    for text in texts:
+        processed.append(clean_single_text(text))
+
+    return processed
+
+
 class RobustModelTrainer:
     """Production-ready model trainer with comprehensive evaluation and validation"""
 
@@ -169,37 +203,12 @@ class RobustModelTrainer:
             logger.error(error_msg)
             return False, None, error_msg
 
-    def preprocess_text(self, text):
-        """Advanced text preprocessing"""
-        import re
-
-        # Convert to string
-        text = str(text)
-
-        # Remove URLs
-        text = re.sub(r'http\S+|www\S+|https\S+', '', text)
-
-        # Remove email addresses
-        text = re.sub(r'\S+@\S+', '', text)
-
-        # Remove excessive punctuation
-        text = re.sub(r'[!]{2,}', '!', text)
-        text = re.sub(r'[?]{2,}', '?', text)
-        text = re.sub(r'[.]{3,}', '...', text)
-
-        # Remove non-alphabetic characters except spaces and basic punctuation
-        text = re.sub(r'[^a-zA-Z\s.!?]', '', text)
-
-        # Remove excessive whitespace
-        text = re.sub(r'\s+', ' ', text)
-
-        return text.strip().lower()
-
     def create_preprocessing_pipeline(self) -> Pipeline:
-        """Create advanced preprocessing pipeline"""
-        # Text preprocessing
+        """Create advanced preprocessing pipeline - pickle-safe"""
+
+        # Use the standalone function instead of lambda
         text_preprocessor = FunctionTransformer(
-            func=lambda x: [self.preprocess_text(text) for text in x],
+            func=preprocess_text_function,  # ✅ Pickle-safe function reference
            validate=False
         )
 
@@ -228,13 +237,6 @@ class RobustModelTrainer:
             ('model', None)  # Will be set during training
         ])
 
-        # After creating the pipeline
-        joblib.dump(pipeline, "/tmp/pipeline.pkl")  # Save complete pipeline
-        # Individual model
-        joblib.dump(pipeline.named_steps['model'], "/tmp/model.pkl")
-        # Individual vectorizer
-        joblib.dump(pipeline.named_steps['vectorize'], "/tmp/vectorizer.pkl")
-
         return pipeline
 
     def comprehensive_evaluation(self, model, X_test, y_test, X_train=None, y_train=None) -> Dict:
@@ -441,10 +443,20 @@ class RobustModelTrainer:
 
             # Save the full pipeline
             joblib.dump(model, self.pipeline_path)
+            logger.info(f"✅ Saved pipeline to {self.pipeline_path}")
 
             # Save individual components for backward compatibility
-            joblib.dump(model.named_steps['model'], self.model_path)
-            joblib.dump(model.named_steps['vectorize'], self.vectorizer_path)
+            if hasattr(model, 'named_steps') and 'model' in model.named_steps:
+                joblib.dump(model.named_steps['model'], self.model_path)
+                logger.info(f"✅ Saved model to {self.model_path}")
+            else:
+                logger.warning("❌ Could not extract model component")
+
+            if hasattr(model, 'named_steps') and 'vectorize' in model.named_steps:
+                joblib.dump(model.named_steps['vectorize'], self.vectorizer_path)
+                logger.info(f"✅ Saved vectorizer to {self.vectorizer_path}")
+            else:
+                logger.warning("❌ Could not extract vectorizer component")
 
             # Generate data hash
             data_hash = hashlib.md5(str(datetime.now()).encode()).hexdigest()
@@ -479,7 +491,7 @@ class RobustModelTrainer:
             with open(self.metadata_path, 'w') as f:
                 json.dump(metadata, f, indent=2)
 
-            logger.info(f"Model artifacts saved successfully")
+            logger.info(f"✅ Model artifacts saved successfully")
             logger.info(f"Model path: {self.model_path}")
             logger.info(f"Vectorizer path: {self.vectorizer_path}")
             logger.info(f"Pipeline path: {self.pipeline_path}")
@@ -592,4 +604,4 @@ def main():
 
 
 if __name__ == "__main__":
-    main()
+    main()
```
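With the guarded save in place, consumers can load either the full pipeline or the individual components. A minimal loading sketch with illustrative artifact paths (the real ones come from the trainer's `model_path`, `vectorizer_path`, and `pipeline_path` settings); note that unpickling the pipeline requires `preprocess_text_function` to be importable from the module where it was defined, which is exactly what replacing the lambda makes possible:

```python
import joblib

# Full pipeline: text preprocessing + vectorizer + model in one object
pipeline = joblib.load("model/pipeline.pkl")  # illustrative path
print(pipeline.predict(["Breaking: example headline with a link http://example.com"]))

# Backward-compatible route: the individually saved components
vectorizer = joblib.load("model/vectorizer.pkl")  # illustrative path
model = joblib.load("model/model.pkl")            # illustrative path
features = vectorizer.transform(["Breaking: example headline"])
print(model.predict(features))
```

The component route skips the `FunctionTransformer` cleaning step, so inputs should be cleaned the same way beforehand if the two paths need to agree.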