Ahmedik95316 commited on
Commit
3a989cc
·
1 Parent(s): ce7aca5

Update model/train.py

Browse files
Files changed (1) hide show
  1. model/train.py +41 -1
model/train.py CHANGED
@@ -111,7 +111,7 @@ class ProgressTracker:
111
  filled_length = int(bar_length * self.current_step // self.total_steps)
112
  bar = '█' * filled_length + '░' * (bar_length - filled_length)
113
 
114
- # Print progress
115
  status_msg = f"\r{self.description}: [{bar}] {progress_pct:.1f}% | Step {self.current_step}/{self.total_steps}"
116
  if step_name:
117
  status_msg += f" | {step_name}"
@@ -120,6 +120,18 @@ class ProgressTracker:
120
 
121
  print(status_msg, end='', flush=True)
122
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  # Store step time for better estimation
124
  if len(self.step_times) >= 3: # Keep last 3 step times for moving average
125
  self.step_times.pop(0)
@@ -773,9 +785,37 @@ def main():
773
  # Parse command line arguments
774
  parser = argparse.ArgumentParser(description='Train fake news detection model')
775
  parser.add_argument('--data_path', type=str, help='Path to training data CSV file')
 
776
  args = parser.parse_args()
777
 
778
  trainer = RobustModelTrainer()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
779
  success, message = trainer.train_model(data_path=args.data_path)
780
 
781
  if success:
 
111
  filled_length = int(bar_length * self.current_step // self.total_steps)
112
  bar = '█' * filled_length + '░' * (bar_length - filled_length)
113
 
114
+ # Print progress (this will be visible in Streamlit logs)
115
  status_msg = f"\r{self.description}: [{bar}] {progress_pct:.1f}% | Step {self.current_step}/{self.total_steps}"
116
  if step_name:
117
  status_msg += f" | {step_name}"
 
120
 
121
  print(status_msg, end='', flush=True)
122
 
123
+ # Also output JSON for Streamlit parsing (if needed)
124
+ progress_json = {
125
+ "type": "progress",
126
+ "step": self.current_step,
127
+ "total": self.total_steps,
128
+ "percentage": progress_pct,
129
+ "eta": str(eta) if eta != "calculating..." else None,
130
+ "step_name": step_name,
131
+ "elapsed": elapsed
132
+ }
133
+ print(f"\nPROGRESS_JSON: {json.dumps(progress_json)}")
134
+
135
  # Store step time for better estimation
136
  if len(self.step_times) >= 3: # Keep last 3 step times for moving average
137
  self.step_times.pop(0)
 
785
  # Parse command line arguments
786
  parser = argparse.ArgumentParser(description='Train fake news detection model')
787
  parser.add_argument('--data_path', type=str, help='Path to training data CSV file')
788
+ parser.add_argument('--config_path', type=str, help='Path to training configuration JSON file')
789
  args = parser.parse_args()
790
 
791
  trainer = RobustModelTrainer()
792
+
793
+ # Load custom configuration if provided
794
+ if args.config_path and Path(args.config_path).exists():
795
+ try:
796
+ with open(args.config_path, 'r') as f:
797
+ config = json.load(f)
798
+
799
+ # Apply configuration
800
+ trainer.test_size = config.get('test_size', trainer.test_size)
801
+ trainer.cv_folds = config.get('cv_folds', trainer.cv_folds)
802
+ trainer.max_features = config.get('max_features', trainer.max_features)
803
+ trainer.ngram_range = tuple(config.get('ngram_range', trainer.ngram_range))
804
+
805
+ # Filter models if specified
806
+ selected_models = config.get('selected_models')
807
+ if selected_models and len(selected_models) < len(trainer.models):
808
+ all_models = trainer.models.copy()
809
+ trainer.models = {k: v for k, v in all_models.items() if k in selected_models}
810
+
811
+ # Update feature selection based on max_features
812
+ trainer.feature_selection_k = min(trainer.feature_selection_k, trainer.max_features)
813
+
814
+ logger.info(f"Applied custom configuration: {config}")
815
+
816
+ except Exception as e:
817
+ logger.warning(f"Failed to load configuration: {e}, using defaults")
818
+
819
  success, message = trainer.train_model(data_path=args.data_path)
820
 
821
  if success: