Commit
Β·
c29bcf3
1
Parent(s):
63bd8f9
Update model/train.py
Browse filesFixed error with model not saving after training
- When GridSearchCV tries to pickle the pipeline during hyperparameter tuning, it fails because lambda functions are not serializable
- Fixed issue in pipeline saving code that tries to save the pipeline before the model is set
- model/train.py +54 -42
model/train.py
CHANGED
@@ -25,10 +25,9 @@ import hashlib
|
|
25 |
from datetime import datetime
|
26 |
from typing import Dict, Tuple, Optional, Any
|
27 |
import warnings
|
|
|
28 |
warnings.filterwarnings('ignore')
|
29 |
|
30 |
-
# Scikit-learn imports
|
31 |
-
|
32 |
# Configure logging
|
33 |
logging.basicConfig(
|
34 |
level=logging.INFO,
|
@@ -41,6 +40,41 @@ logging.basicConfig(
|
|
41 |
logger = logging.getLogger(__name__)
|
42 |
|
43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
class RobustModelTrainer:
|
45 |
"""Production-ready model trainer with comprehensive evaluation and validation"""
|
46 |
|
@@ -169,37 +203,12 @@ class RobustModelTrainer:
|
|
169 |
logger.error(error_msg)
|
170 |
return False, None, error_msg
|
171 |
|
172 |
-
def preprocess_text(self, text):
|
173 |
-
"""Advanced text preprocessing"""
|
174 |
-
import re
|
175 |
-
|
176 |
-
# Convert to string
|
177 |
-
text = str(text)
|
178 |
-
|
179 |
-
# Remove URLs
|
180 |
-
text = re.sub(r'http\S+|www\S+|https\S+', '', text)
|
181 |
-
|
182 |
-
# Remove email addresses
|
183 |
-
text = re.sub(r'\S+@\S+', '', text)
|
184 |
-
|
185 |
-
# Remove excessive punctuation
|
186 |
-
text = re.sub(r'[!]{2,}', '!', text)
|
187 |
-
text = re.sub(r'[?]{2,}', '?', text)
|
188 |
-
text = re.sub(r'[.]{3,}', '...', text)
|
189 |
-
|
190 |
-
# Remove non-alphabetic characters except spaces and basic punctuation
|
191 |
-
text = re.sub(r'[^a-zA-Z\s.!?]', '', text)
|
192 |
-
|
193 |
-
# Remove excessive whitespace
|
194 |
-
text = re.sub(r'\s+', ' ', text)
|
195 |
-
|
196 |
-
return text.strip().lower()
|
197 |
-
|
198 |
def create_preprocessing_pipeline(self) -> Pipeline:
|
199 |
-
"""Create advanced preprocessing pipeline"""
|
200 |
-
|
|
|
201 |
text_preprocessor = FunctionTransformer(
|
202 |
-
func=
|
203 |
validate=False
|
204 |
)
|
205 |
|
@@ -228,13 +237,6 @@ class RobustModelTrainer:
|
|
228 |
('model', None) # Will be set during training
|
229 |
])
|
230 |
|
231 |
-
# After creating the pipeline
|
232 |
-
joblib.dump(pipeline, "/tmp/pipeline.pkl") # Save complete pipeline
|
233 |
-
# Individual model
|
234 |
-
joblib.dump(pipeline.named_steps['model'], "/tmp/model.pkl")
|
235 |
-
# Individual vectorizer
|
236 |
-
joblib.dump(pipeline.named_steps['vectorize'], "/tmp/vectorizer.pkl")
|
237 |
-
|
238 |
return pipeline
|
239 |
|
240 |
def comprehensive_evaluation(self, model, X_test, y_test, X_train=None, y_train=None) -> Dict:
|
@@ -441,10 +443,20 @@ class RobustModelTrainer:
|
|
441 |
|
442 |
# Save the full pipeline
|
443 |
joblib.dump(model, self.pipeline_path)
|
|
|
444 |
|
445 |
# Save individual components for backward compatibility
|
446 |
-
|
447 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
448 |
|
449 |
# Generate data hash
|
450 |
data_hash = hashlib.md5(str(datetime.now()).encode()).hexdigest()
|
@@ -479,7 +491,7 @@ class RobustModelTrainer:
|
|
479 |
with open(self.metadata_path, 'w') as f:
|
480 |
json.dump(metadata, f, indent=2)
|
481 |
|
482 |
-
logger.info(f"Model artifacts saved successfully")
|
483 |
logger.info(f"Model path: {self.model_path}")
|
484 |
logger.info(f"Vectorizer path: {self.vectorizer_path}")
|
485 |
logger.info(f"Pipeline path: {self.pipeline_path}")
|
@@ -592,4 +604,4 @@ def main():
|
|
592 |
|
593 |
|
594 |
if __name__ == "__main__":
|
595 |
-
main()
|
|
|
25 |
from datetime import datetime
|
26 |
from typing import Dict, Tuple, Optional, Any
|
27 |
import warnings
|
28 |
+
import re
|
29 |
warnings.filterwarnings('ignore')
|
30 |
|
|
|
|
|
31 |
# Configure logging
|
32 |
logging.basicConfig(
|
33 |
level=logging.INFO,
|
|
|
40 |
logger = logging.getLogger(__name__)
|
41 |
|
42 |
|
43 |
+
def preprocess_text_function(texts):
|
44 |
+
"""
|
45 |
+
Standalone function for text preprocessing - pickle-safe
|
46 |
+
"""
|
47 |
+
def clean_single_text(text):
|
48 |
+
# Convert to string
|
49 |
+
text = str(text)
|
50 |
+
|
51 |
+
# Remove URLs
|
52 |
+
text = re.sub(r'http\S+|www\S+|https\S+', '', text)
|
53 |
+
|
54 |
+
# Remove email addresses
|
55 |
+
text = re.sub(r'\S+@\S+', '', text)
|
56 |
+
|
57 |
+
# Remove excessive punctuation
|
58 |
+
text = re.sub(r'[!]{2,}', '!', text)
|
59 |
+
text = re.sub(r'[?]{2,}', '?', text)
|
60 |
+
text = re.sub(r'[.]{3,}', '...', text)
|
61 |
+
|
62 |
+
# Remove non-alphabetic characters except spaces and basic punctuation
|
63 |
+
text = re.sub(r'[^a-zA-Z\s.!?]', '', text)
|
64 |
+
|
65 |
+
# Remove excessive whitespace
|
66 |
+
text = re.sub(r'\s+', ' ', text)
|
67 |
+
|
68 |
+
return text.strip().lower()
|
69 |
+
|
70 |
+
# Process all texts
|
71 |
+
processed = []
|
72 |
+
for text in texts:
|
73 |
+
processed.append(clean_single_text(text))
|
74 |
+
|
75 |
+
return processed
|
76 |
+
|
77 |
+
|
78 |
class RobustModelTrainer:
|
79 |
"""Production-ready model trainer with comprehensive evaluation and validation"""
|
80 |
|
|
|
203 |
logger.error(error_msg)
|
204 |
return False, None, error_msg
|
205 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
206 |
def create_preprocessing_pipeline(self) -> Pipeline:
|
207 |
+
"""Create advanced preprocessing pipeline - pickle-safe"""
|
208 |
+
|
209 |
+
# Use the standalone function instead of lambda
|
210 |
text_preprocessor = FunctionTransformer(
|
211 |
+
func=preprocess_text_function, # β
Pickle-safe function reference
|
212 |
validate=False
|
213 |
)
|
214 |
|
|
|
237 |
('model', None) # Will be set during training
|
238 |
])
|
239 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
240 |
return pipeline
|
241 |
|
242 |
def comprehensive_evaluation(self, model, X_test, y_test, X_train=None, y_train=None) -> Dict:
|
|
|
443 |
|
444 |
# Save the full pipeline
|
445 |
joblib.dump(model, self.pipeline_path)
|
446 |
+
logger.info(f"β
Saved pipeline to {self.pipeline_path}")
|
447 |
|
448 |
# Save individual components for backward compatibility
|
449 |
+
if hasattr(model, 'named_steps') and 'model' in model.named_steps:
|
450 |
+
joblib.dump(model.named_steps['model'], self.model_path)
|
451 |
+
logger.info(f"β
Saved model to {self.model_path}")
|
452 |
+
else:
|
453 |
+
logger.warning("β Could not extract model component")
|
454 |
+
|
455 |
+
if hasattr(model, 'named_steps') and 'vectorize' in model.named_steps:
|
456 |
+
joblib.dump(model.named_steps['vectorize'], self.vectorizer_path)
|
457 |
+
logger.info(f"β
Saved vectorizer to {self.vectorizer_path}")
|
458 |
+
else:
|
459 |
+
logger.warning("β Could not extract vectorizer component")
|
460 |
|
461 |
# Generate data hash
|
462 |
data_hash = hashlib.md5(str(datetime.now()).encode()).hexdigest()
|
|
|
491 |
with open(self.metadata_path, 'w') as f:
|
492 |
json.dump(metadata, f, indent=2)
|
493 |
|
494 |
+
logger.info(f"β
Model artifacts saved successfully")
|
495 |
logger.info(f"Model path: {self.model_path}")
|
496 |
logger.info(f"Vectorizer path: {self.vectorizer_path}")
|
497 |
logger.info(f"Pipeline path: {self.pipeline_path}")
|
|
|
604 |
|
605 |
|
606 |
if __name__ == "__main__":
|
607 |
+
main()
|