Ahmedik95316 committed
Commit aff6c9c · 1 Parent(s): 620e5bd

Update monitor/monitor_drift.py


Adding Automated Retraining Triggers

Files changed (1):
  1. monitor/monitor_drift.py  +484 -0
monitor/monitor_drift.py CHANGED
@@ -667,6 +667,490 @@ class AdvancedDriftMonitor:
            logger.error(f"Drift monitoring failed: {e}")
            return None

+    def setup_automation_config(self):
+        """Setup automation-specific configuration"""
+        self.automation_config = {
+            'retraining_thresholds': {
+                'drift_score': 0.2,
+                'consecutive_detections': 3,
+                'performance_drop': 0.05,
+                'data_volume_threshold': 1000,
+                'time_since_last_training': timedelta(days=7)
+            },
+            'monitoring_schedule': {
+                'check_interval': timedelta(hours=6),
+                'force_check_interval': timedelta(days=1),
+                'max_monitoring_failures': 5
+            },
+            'emergency_thresholds': {
+                'critical_drift_score': 0.4,
+                'critical_performance_drop': 0.15,
+                'emergency_action_required': True
+            },
+            'data_quality_thresholds': {
+                'min_samples_for_detection': 100,
+                'min_samples_for_retraining': 500,
+                'data_freshness_hours': 24
+            }
+        }
+
+    def check_retraining_triggers(self, drift_results: Optional[Dict] = None) -> Dict:
+        """Check if retraining should be triggered based on multiple criteria"""
+        try:
+            trigger_results = {
+                'should_retrain': False,
+                'trigger_reason': None,
+                'urgency': 'none',
+                'triggers_detected': [],
+                'data_quality_check': {},
+                'recommendations': []
+            }
+
+            # Perform drift monitoring if results were not provided
+            if drift_results is None:
+                reference_df, current_df = self.load_and_prepare_data()
+                if reference_df is None or current_df is None:
+                    trigger_results['trigger_reason'] = 'insufficient_data'
+                    return trigger_results
+
+                drift_results = self.comprehensive_drift_detection(reference_df, current_df)
+                if 'error' in drift_results:
+                    trigger_results['trigger_reason'] = f"drift_detection_error: {drift_results['error']}"
+                    return trigger_results
+
+            # Check drift-based triggers
+            drift_triggers = self.check_drift_triggers(drift_results)
+            trigger_results['triggers_detected'].extend(drift_triggers)
+
+            # Check data volume triggers
+            volume_triggers = self.check_data_volume_triggers()
+            trigger_results['triggers_detected'].extend(volume_triggers)
+
+            # Check time-based triggers
+            time_triggers = self.check_time_based_triggers()
+            trigger_results['triggers_detected'].extend(time_triggers)
+
+            # Check data quality
+            trigger_results['data_quality_check'] = self.check_data_quality()
+
+            # Determine if retraining should be triggered
+            trigger_results = self.evaluate_retraining_decision(trigger_results, drift_results)
+
+            # Save trigger evaluation
+            self.save_trigger_evaluation(trigger_results)
+
+            return trigger_results
+
+        except Exception as e:
+            logger.error(f"Retraining trigger check failed: {e}")
+            return {
+                'should_retrain': False,
+                'trigger_reason': f'trigger_check_error: {str(e)}',
+                'urgency': 'none',
+                'triggers_detected': [],
+                'error': str(e)
+            }
+
+    def check_drift_triggers(self, drift_results: Dict) -> List[Dict]:
+        """Check drift-based retraining triggers"""
+        triggers = []
+
+        # Overall drift score trigger
+        overall_score = drift_results.get('overall_drift_score', 0)
+        if overall_score > self.automation_config['retraining_thresholds']['drift_score']:
+            triggers.append({
+                'type': 'drift_score',
+                'severity': 'high' if overall_score > self.automation_config['emergency_thresholds']['critical_drift_score'] else 'medium',
+                'value': overall_score,
+                'threshold': self.automation_config['retraining_thresholds']['drift_score'],
+                'message': f"Drift score {overall_score:.3f} exceeds threshold {self.automation_config['retraining_thresholds']['drift_score']}"
+            })
+
+        # Performance degradation trigger
+        perf_results = drift_results.get('individual_methods', {}).get('performance_drift', {})
+        if 'performance_drop' in perf_results:
+            perf_drop = perf_results['performance_drop']
+            if perf_drop > self.automation_config['retraining_thresholds']['performance_drop']:
+                triggers.append({
+                    'type': 'performance_degradation',
+                    'severity': 'critical' if perf_drop > self.automation_config['emergency_thresholds']['critical_performance_drop'] else 'high',
+                    'value': perf_drop,
+                    'threshold': self.automation_config['retraining_thresholds']['performance_drop'],
+                    'message': f"Performance drop {perf_drop:.3f} exceeds threshold"
+                })
+
+        # Consecutive detection trigger
+        consecutive_detections = self.count_consecutive_drift_detections()
+        if consecutive_detections >= self.automation_config['retraining_thresholds']['consecutive_detections']:
+            triggers.append({
+                'type': 'consecutive_detections',
+                'severity': 'medium',
+                'value': consecutive_detections,
+                'threshold': self.automation_config['retraining_thresholds']['consecutive_detections'],
+                'message': f"Drift detected in {consecutive_detections} consecutive monitoring cycles"
+            })
+
+        return triggers
+
+    def check_data_volume_triggers(self) -> List[Dict]:
+        """Check data volume-based triggers"""
+        triggers = []
+
+        try:
+            # Count new data since last training
+            new_data_count = self.count_new_data_since_training()
+
+            if new_data_count >= self.automation_config['retraining_thresholds']['data_volume_threshold']:
+                triggers.append({
+                    'type': 'data_volume',
+                    'severity': 'low',
+                    'value': new_data_count,
+                    'threshold': self.automation_config['retraining_thresholds']['data_volume_threshold'],
+                    'message': f"Accumulated {new_data_count} new samples since last training"
+                })
+
+            return triggers
+
+        except Exception as e:
+            logger.warning(f"Data volume trigger check failed: {e}")
+            return []
+
+    def check_time_based_triggers(self) -> List[Dict]:
+        """Check time-based retraining triggers"""
+        triggers = []
+
+        try:
+            # Get last training time
+            last_training_time = self.get_last_training_time()
+
+            if last_training_time:
+                time_since_training = datetime.now() - last_training_time
+                threshold = self.automation_config['retraining_thresholds']['time_since_last_training']
+
+                if time_since_training > threshold:
+                    triggers.append({
+                        'type': 'time_since_training',
+                        'severity': 'low',
+                        'value': time_since_training.days,
+                        'threshold': threshold.days,
+                        'message': f"Last training was {time_since_training.days} days ago"
+                    })
+
+            return triggers
+
+        except Exception as e:
+            logger.warning(f"Time-based trigger check failed: {e}")
+            return []
+
+    def check_data_quality(self) -> Dict:
+        """Check data quality for retraining"""
+        quality_check = {
+            'sufficient_data': False,
+            'data_freshness': False,
+            'data_balance': False,
+            'overall_quality': 'poor',
+            'issues': []
+        }
+
+        try:
+            # Load current data
+            _, current_df = self.load_and_prepare_data()
+
+            if current_df is None or len(current_df) == 0:
+                quality_check['issues'].append('No current data available')
+                return quality_check
+
+            # Check data volume
+            min_samples = self.automation_config['data_quality_thresholds']['min_samples_for_retraining']
+            if len(current_df) >= min_samples:
+                quality_check['sufficient_data'] = True
+            else:
+                quality_check['issues'].append(f'Insufficient data: {len(current_df)} < {min_samples}')
+
+            # Check data freshness
+            if 'timestamp' in current_df.columns:
+                try:
+                    current_df['timestamp'] = pd.to_datetime(current_df['timestamp'])
+                    latest_data = current_df['timestamp'].max()
+                    freshness_threshold = datetime.now() - timedelta(
+                        hours=self.automation_config['data_quality_thresholds']['data_freshness_hours']
+                    )
+
+                    if latest_data > freshness_threshold:
+                        quality_check['data_freshness'] = True
+                    else:
+                        quality_check['issues'].append('Data is not fresh enough')
+                except Exception:
+                    quality_check['issues'].append('Cannot determine data freshness')
+
+            # Check data balance if labels available
+            if 'label' in current_df.columns:
+                label_counts = current_df['label'].value_counts()
+                if len(label_counts) > 1:
+                    balance_ratio = label_counts.min() / label_counts.max()
+                    if balance_ratio > 0.3:  # At least 30% minority class
+                        quality_check['data_balance'] = True
+                    else:
+                        quality_check['issues'].append(f'Data imbalance: ratio {balance_ratio:.2f}')
+
+            # Overall quality assessment
+            quality_score = sum([
+                quality_check['sufficient_data'],
+                quality_check['data_freshness'],
+                quality_check['data_balance']
+            ])
+
+            if quality_score >= 3:
+                quality_check['overall_quality'] = 'excellent'
+            elif quality_score >= 2:
+                quality_check['overall_quality'] = 'good'
+            elif quality_score >= 1:
+                quality_check['overall_quality'] = 'fair'
+            else:
+                quality_check['overall_quality'] = 'poor'
+
+            return quality_check
+
+        except Exception as e:
+            logger.error(f"Data quality check failed: {e}")
+            quality_check['issues'].append(f'Quality check error: {str(e)}')
+            return quality_check
+
+    def evaluate_retraining_decision(self, trigger_results: Dict, drift_results: Dict) -> Dict:
+        """Evaluate whether retraining should be triggered"""
+
+        triggers = trigger_results['triggers_detected']
+        data_quality = trigger_results['data_quality_check']
+
+        # Count trigger types and severities
+        critical_triggers = [t for t in triggers if t['severity'] == 'critical']
+        high_triggers = [t for t in triggers if t['severity'] == 'high']
+        medium_triggers = [t for t in triggers if t['severity'] == 'medium']
+
+        # Decision logic
+        should_retrain = False
+        urgency = 'none'
+        reason = None
+        recommendations = []
+
+        # Critical triggers - immediate retraining
+        if critical_triggers:
+            should_retrain = True
+            urgency = 'critical'
+            reason = f"Critical triggers detected: {[t['type'] for t in critical_triggers]}"
+            recommendations.extend([
+                "URGENT: Critical model degradation detected",
+                "Stop current model serving if possible",
+                "Initiate emergency retraining immediately"
+            ])
+
+        # High priority triggers - urgent retraining
+        elif high_triggers:
+            if data_quality['overall_quality'] in ['good', 'excellent']:
+                should_retrain = True
+                urgency = 'high'
+                reason = f"High priority triggers with good data quality: {[t['type'] for t in high_triggers]}"
+                recommendations.extend([
+                    "High priority retraining recommended",
+                    "Schedule retraining within 24 hours"
+                ])
+            else:
+                recommendations.extend([
+                    "High priority triggers detected but data quality insufficient",
+                    "Improve data quality before retraining"
+                ])
+
+        # Medium priority triggers - scheduled retraining
+        elif len(medium_triggers) >= 2 or len(triggers) >= 3:
+            if data_quality['overall_quality'] in ['good', 'excellent', 'fair']:
+                should_retrain = True
+                urgency = 'medium'
+                reason = f"Multiple triggers detected: {[t['type'] for t in triggers]}"
+                recommendations.extend([
+                    "Multiple retraining indicators detected",
+                    "Schedule retraining within next maintenance window"
+                ])
+
+        # Single medium or low priority triggers
+        elif triggers:
+            recommendations.extend([
+                "Some retraining indicators detected",
+                "Monitor closely and prepare for retraining",
+                f"Triggers: {[t['type'] for t in triggers]}"
+            ])
+
+        # Update results
+        trigger_results.update({
+            'should_retrain': should_retrain,
+            'urgency': urgency,
+            'trigger_reason': reason,
+            'recommendations': recommendations
+        })
+
+        return trigger_results
+
+    def count_consecutive_drift_detections(self) -> int:
+        """Count consecutive drift detections from historical data"""
+        try:
+            if not self.drift_log_path.exists():
+                return 0
+
+            with open(self.drift_log_path, 'r') as f:
+                logs = json.load(f)
+
+            if not logs:
+                return 0
+
+            # Sort by timestamp and count consecutive detections
+            logs_sorted = sorted(logs, key=lambda x: x.get('timestamp', ''))
+            consecutive_count = 0
+
+            for log_entry in reversed(logs_sorted[-10:]):  # Check last 10 entries
+                if log_entry.get('overall_drift_detected', False):
+                    consecutive_count += 1
+                else:
+                    break
+
+            return consecutive_count
+
+        except Exception as e:
+            logger.warning(f"Failed to count consecutive detections: {e}")
+            return 0
+
+    def count_new_data_since_training(self) -> int:
+        """Count new data samples since last training"""
+        try:
+            last_training_time = self.get_last_training_time()
+            if not last_training_time:
+                return 0
+
+            # Count data from current sources
+            total_count = 0
+
+            for data_path in [self.current_data_path, self.generated_data_path]:
+                if data_path.exists():
+                    df = pd.read_csv(data_path)
+                    if 'timestamp' in df.columns:
+                        df['timestamp'] = pd.to_datetime(df['timestamp'])
+                        new_data = df[df['timestamp'] > last_training_time]
+                        total_count += len(new_data)
+                    else:
+                        # If no timestamp, assume all data is new
+                        total_count += len(df)
+
+            return total_count
+
+        except Exception as e:
+            logger.warning(f"Failed to count new data: {e}")
+            return 0
+
+    def get_last_training_time(self) -> Optional[datetime]:
+        """Get timestamp of last model training"""
+        try:
+            # Check model metadata
+            metadata_path = self.model_dir / "metadata.json"
+            if metadata_path.exists():
+                with open(metadata_path, 'r') as f:
+                    metadata = json.load(f)
+
+                timestamp_str = metadata.get('timestamp')
+                if timestamp_str:
+                    # Normalize to a naive datetime so comparisons with
+                    # datetime.now() never mix aware and naive values
+                    parsed = datetime.fromisoformat(timestamp_str.replace('Z', '+00:00'))
+                    return parsed.replace(tzinfo=None) if parsed.tzinfo else parsed
+
+            # Fallback to model file modification time
+            for model_path in [self.pipeline_path, self.model_path]:
+                if model_path.exists():
+                    return datetime.fromtimestamp(model_path.stat().st_mtime)
+
+            return None
+
+        except Exception as e:
+            logger.warning(f"Failed to get last training time: {e}")
+            return None
+
+    def save_trigger_evaluation(self, trigger_results: Dict):
+        """Save trigger evaluation results"""
+        try:
+            trigger_log_path = self.logs_dir / "retraining_triggers.json"
+
+            # Load existing logs
+            logs = []
+            if trigger_log_path.exists():
+                try:
+                    with open(trigger_log_path, 'r') as f:
+                        logs = json.load(f)
+                except Exception:
+                    logs = []
+
+            # Add timestamp and save
+            trigger_results['evaluation_timestamp'] = datetime.now().isoformat()
+            logs.append(trigger_results)
+
+            # Keep only last 100 evaluations
+            if len(logs) > 100:
+                logs = logs[-100:]
+
+            with open(trigger_log_path, 'w') as f:
+                json.dump(logs, f, indent=2)
+
+            logger.info(f"Trigger evaluation saved to {trigger_log_path}")
+
+        except Exception as e:
+            logger.error(f"Failed to save trigger evaluation: {e}")
+
+    def get_automation_status(self) -> Dict:
+        """Get current automation status and recent trigger evaluations"""
+        try:
+            status = {
+                'automation_active': True,
+                'last_drift_check': None,
+                'last_trigger_evaluation': None,
+                'recent_triggers': [],
+                'data_quality_status': {},
+                'next_scheduled_check': None
+            }
+
+            # Get last drift check
+            if self.drift_log_path.exists():
+                try:
+                    with open(self.drift_log_path, 'r') as f:
+                        logs = json.load(f)
+                    if logs:
+                        status['last_drift_check'] = logs[-1].get('timestamp')
+                except Exception:
+                    pass
+
+            # Get recent trigger evaluations
+            trigger_log_path = self.logs_dir / "retraining_triggers.json"
+            if trigger_log_path.exists():
+                try:
+                    with open(trigger_log_path, 'r') as f:
+                        trigger_logs = json.load(f)
+
+                    if trigger_logs:
+                        status['last_trigger_evaluation'] = trigger_logs[-1].get('evaluation_timestamp')
+                        status['recent_triggers'] = trigger_logs[-5:]  # Last 5 evaluations
+                except Exception:
+                    pass
+
+            # Get current data quality
+            status['data_quality_status'] = self.check_data_quality()
+
+            return status
+
+        except Exception as e:
+            logger.error(f"Failed to get automation status: {e}")
+            return {'automation_active': False, 'error': str(e)}
+
+    # Updated __init__: wires setup_automation_config() into initialization
+    def __init__(self):
+        self.setup_paths()
+        self.setup_drift_config()
+        self.setup_automation_config()
+        self.setup_drift_methods()
+        self.historical_data = self.load_historical_data()
+
+
  def monitor_drift():
      """Main function for external calls"""
      monitor = AdvancedDriftMonitor()
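
For reference, a minimal sketch of how the trigger API added in this commit might be driven from a scheduled job (cron, CI workflow, or similar). Only AdvancedDriftMonitor, check_retraining_triggers(), and get_automation_status() come from monitor/monitor_drift.py; the import path and the run_retraining placeholder are assumptions, not part of the repository.

    # Hypothetical caller for the automation API added in this commit.
    # Assumes monitor/monitor_drift.py is importable as monitor.monitor_drift;
    # run_retraining() is a placeholder for the project's real training entry point.
    import logging

    from monitor.monitor_drift import AdvancedDriftMonitor

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)


    def run_retraining(urgency: str) -> None:
        # Placeholder: swap in the actual retraining pipeline call here.
        logger.info("Retraining requested (urgency=%s)", urgency)


    def scheduled_check() -> None:
        monitor = AdvancedDriftMonitor()

        # Evaluate all trigger types: drift, data volume, time since training, data quality.
        result = monitor.check_retraining_triggers()
        logger.info("Trigger evaluation: %s (urgency=%s)",
                    result.get('trigger_reason'), result.get('urgency'))

        if result.get('should_retrain'):
            run_retraining(result.get('urgency', 'none'))
        else:
            # Surface the current automation state for dashboards or alerts.
            status = monitor.get_automation_status()
            logger.info("No retraining needed; last drift check at %s",
                        status.get('last_drift_check'))


    if __name__ == "__main__":
        scheduled_check()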