Commit
·
3a989cc
1
Parent(s):
ce7aca5
Update model/train.py
Browse files
- model/train.py +41 -1
model/train.py
CHANGED
@@ -111,7 +111,7 @@ class ProgressTracker:
|
|
111 |
filled_length = int(bar_length * self.current_step // self.total_steps)
|
112 |
bar = '█' * filled_length + '░' * (bar_length - filled_length)
|
113 |
|
114 |
-
# Print progress
|
115 |
status_msg = f"\r{self.description}: [{bar}] {progress_pct:.1f}% | Step {self.current_step}/{self.total_steps}"
|
116 |
if step_name:
|
117 |
status_msg += f" | {step_name}"
|
@@ -120,6 +120,18 @@ class ProgressTracker:
|
|
120 |
|
121 |
print(status_msg, end='', flush=True)
|
122 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
# Store step time for better estimation
|
124 |
if len(self.step_times) >= 3: # Keep last 3 step times for moving average
|
125 |
self.step_times.pop(0)
|
@@ -773,9 +785,37 @@ def main():
|
|
773 |
# Parse command line arguments
|
774 |
parser = argparse.ArgumentParser(description='Train fake news detection model')
|
775 |
parser.add_argument('--data_path', type=str, help='Path to training data CSV file')
|
|
|
776 |
args = parser.parse_args()
|
777 |
|
778 |
trainer = RobustModelTrainer()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
779 |
success, message = trainer.train_model(data_path=args.data_path)
|
780 |
|
781 |
if success:
|
|
|
111 |
filled_length = int(bar_length * self.current_step // self.total_steps)
|
112 |
bar = '█' * filled_length + '░' * (bar_length - filled_length)
|
113 |
|
114 |
+
# Print progress (this will be visible in Streamlit logs)
|
115 |
status_msg = f"\r{self.description}: [{bar}] {progress_pct:.1f}% | Step {self.current_step}/{self.total_steps}"
|
116 |
if step_name:
|
117 |
status_msg += f" | {step_name}"
|
|
|
120 |
|
121 |
print(status_msg, end='', flush=True)
|
122 |
|
123 |
+
# Also output JSON for Streamlit parsing (if needed)
|
124 |
+
progress_json = {
|
125 |
+
"type": "progress",
|
126 |
+
"step": self.current_step,
|
127 |
+
"total": self.total_steps,
|
128 |
+
"percentage": progress_pct,
|
129 |
+
"eta": str(eta) if eta != "calculating..." else None,
|
130 |
+
"step_name": step_name,
|
131 |
+
"elapsed": elapsed
|
132 |
+
}
|
133 |
+
print(f"\nPROGRESS_JSON: {json.dumps(progress_json)}")
|
134 |
+
|
135 |
# Store step time for better estimation
|
136 |
if len(self.step_times) >= 3: # Keep last 3 step times for moving average
|
137 |
self.step_times.pop(0)
|
|
|
785 |
# Parse command line arguments
|
786 |
parser = argparse.ArgumentParser(description='Train fake news detection model')
|
787 |
parser.add_argument('--data_path', type=str, help='Path to training data CSV file')
|
788 |
+
parser.add_argument('--config_path', type=str, help='Path to training configuration JSON file')
|
789 |
args = parser.parse_args()
|
790 |
|
791 |
trainer = RobustModelTrainer()
|
792 |
+
|
793 |
+
# Load custom configuration if provided
|
794 |
+
if args.config_path and Path(args.config_path).exists():
|
795 |
+
try:
|
796 |
+
with open(args.config_path, 'r') as f:
|
797 |
+
config = json.load(f)
|
798 |
+
|
799 |
+
# Apply configuration
|
800 |
+
trainer.test_size = config.get('test_size', trainer.test_size)
|
801 |
+
trainer.cv_folds = config.get('cv_folds', trainer.cv_folds)
|
802 |
+
trainer.max_features = config.get('max_features', trainer.max_features)
|
803 |
+
trainer.ngram_range = tuple(config.get('ngram_range', trainer.ngram_range))
|
804 |
+
|
805 |
+
# Filter models if specified
|
806 |
+
selected_models = config.get('selected_models')
|
807 |
+
if selected_models and len(selected_models) < len(trainer.models):
|
808 |
+
all_models = trainer.models.copy()
|
809 |
+
trainer.models = {k: v for k, v in all_models.items() if k in selected_models}
|
810 |
+
|
811 |
+
# Update feature selection based on max_features
|
812 |
+
trainer.feature_selection_k = min(trainer.feature_selection_k, trainer.max_features)
|
813 |
+
|
814 |
+
logger.info(f"Applied custom configuration: {config}")
|
815 |
+
|
816 |
+
except Exception as e:
|
817 |
+
logger.warning(f"Failed to load configuration: {e}, using defaults")
|
818 |
+
|
819 |
success, message = trainer.train_model(data_path=args.data_path)
|
820 |
|
821 |
if success:
|