AlvaroMros committed on
Commit
1e04613
·
1 Parent(s): b7da0df

Integrate model saving into prediction pipeline

Browse files

Moved model training and saving logic from save_model.py into the PredictionPipeline class. Updated config to define MODELS_DIR and refactored predict_new.py to use the new model directory. Removed the now-redundant save_model.py script.

src/config.py CHANGED
@@ -1,11 +1,7 @@
1
  import os
2
 
3
  OUTPUT_DIR = 'output'
 
 
4
  FIGHTS_CSV_PATH = os.path.join(OUTPUT_DIR, 'ufc_fights.csv')
5
  FIGHTERS_CSV_PATH = os.path.join(OUTPUT_DIR, 'ufc_fighters.csv')
6
- MODEL_RESULTS_PATH = os.path.join(OUTPUT_DIR, 'model_results.json')
7
-
8
- # JSON files (temporary)
9
- EVENTS_JSON_PATH = os.path.join(OUTPUT_DIR, 'ufc_fights.json')
10
- FIGHTERS_JSON_PATH = os.path.join(OUTPUT_DIR, 'ufc_fighters.json')
11
-
 
1
  import os
2
 
3
  OUTPUT_DIR = 'output'
4
+ MODELS_DIR = os.path.join(OUTPUT_DIR, 'models')
5
+ MODEL_RESULTS_PATH = os.path.join(OUTPUT_DIR, 'model_results.json')
6
  FIGHTS_CSV_PATH = os.path.join(OUTPUT_DIR, 'ufc_fights.csv')
7
  FIGHTERS_CSV_PATH = os.path.join(OUTPUT_DIR, 'ufc_fighters.csv')
 
 
 
 
 
 
src/predict/pipeline.py CHANGED
@@ -4,8 +4,9 @@ import sys
4
  from datetime import datetime
5
  from collections import OrderedDict
6
  import json
 
7
 
8
- from ..config import FIGHTS_CSV_PATH, MODEL_RESULTS_PATH
9
  from .models import BaseModel
10
 
11
  class PredictionPipeline:
@@ -43,7 +44,7 @@ class PredictionPipeline:
43
  print(f"Testing on the last {num_test_events} events.")
44
 
45
  def run(self, detailed_report=True):
46
- """Executes the full pipeline: load, train, evaluate, and report."""
47
  self._load_and_split_data()
48
 
49
  eval_fights = [f for f in self.test_fights if f['winner'] not in ["Draw", "NC", ""]]
@@ -91,6 +92,36 @@ class PredictionPipeline:
91
  else:
92
  self._report_summary()
93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  def _report_summary(self):
95
  """Prints a concise summary of model performance."""
96
  print("\n\n--- Prediction Pipeline Summary ---")
 
4
  from datetime import datetime
5
  from collections import OrderedDict
6
  import json
7
+ import joblib
8
 
9
+ from ..config import FIGHTS_CSV_PATH, MODEL_RESULTS_PATH, MODELS_DIR
10
  from .models import BaseModel
11
 
12
  class PredictionPipeline:
 
44
  print(f"Testing on the last {num_test_events} events.")
45
 
46
  def run(self, detailed_report=True):
47
+ """Executes the full pipeline: load, train, evaluate, report and save models."""
48
  self._load_and_split_data()
49
 
50
  eval_fights = [f for f in self.test_fights if f['winner'] not in ["Draw", "NC", ""]]
 
92
  else:
93
  self._report_summary()
94
 
95
+ self._train_and_save_models()
96
+
97
+ def _train_and_save_models(self):
98
+ """Trains all models on the full dataset and saves them."""
99
+ print("\n\n--- Training and Saving All Models on Full Dataset ---")
100
+
101
+ if not os.path.exists(FIGHTS_CSV_PATH):
102
+ print(f"Error: Fights data not found at '{FIGHTS_CSV_PATH}'. Cannot save models.")
103
+ return
104
+
105
+ with open(FIGHTS_CSV_PATH, 'r', encoding='utf-8') as f:
106
+ all_fights = list(csv.DictReader(f))
107
+
108
+ print(f"Training models on all {len(all_fights)} available fights...")
109
+
110
+ if not os.path.exists(MODELS_DIR):
111
+ os.makedirs(MODELS_DIR)
112
+ print(f"Created directory: {MODELS_DIR}")
113
+
114
+ for model in self.models:
115
+ model_name = model.__class__.__name__
116
+ print(f"\n--- Training: {model_name} ---")
117
+ model.train(all_fights)
118
+
119
+ # Sanitize and save the model
120
+ file_name = f"{model_name}.joblib"
121
+ save_path = os.path.join(MODELS_DIR, file_name)
122
+ joblib.dump(model, save_path)
123
+ print(f"Model saved successfully to {save_path}")
124
+
125
  def _report_summary(self):
126
  """Prints a concise summary of model performance."""
127
  print("\n\n--- Prediction Pipeline Summary ---")
src/predict/predict_new.py CHANGED
@@ -3,7 +3,7 @@ import os
3
  import joblib
4
  from datetime import datetime
5
 
6
- from ..config import OUTPUT_DIR
7
 
8
  def predict_new_fight(fighter1_name, fighter2_name, model_path):
9
  """
@@ -45,7 +45,7 @@ if __name__ == '__main__':
45
  parser.add_argument(
46
  '--model_path',
47
  type=str,
48
- default=os.path.join(OUTPUT_DIR, 'XGBoostModel.joblib'),
49
  help="Path to the saved model file."
50
  )
51
  args = parser.parse_args()
 
3
  import joblib
4
  from datetime import datetime
5
 
6
+ from ..config import MODELS_DIR
7
 
8
  def predict_new_fight(fighter1_name, fighter2_name, model_path):
9
  """
 
45
  parser.add_argument(
46
  '--model_path',
47
  type=str,
48
+ default=os.path.join(MODELS_DIR, 'XGBoostModel.joblib'),
49
  help="Path to the saved model file."
50
  )
51
  args = parser.parse_args()
src/predict/save_model.py DELETED
@@ -1,53 +0,0 @@
1
- import argparse
2
- import os
3
- import joblib
4
- import pandas as pd
5
-
6
- from ..config import FIGHTS_CSV_PATH, OUTPUT_DIR
7
- import src.predict.models as models
8
-
9
- def save_model(model_name):
10
- """
11
- Trains a specified model on the entire dataset and saves it to a file.
12
-
13
- :param model_name: The name of the model class to train (e.g., 'XGBoostModel').
14
- """
15
- print(f"--- Training and Saving Model: {model_name} ---")
16
-
17
- # 1. Get the model class from the models module
18
- try:
19
- ModelClass = getattr(models, model_name)
20
- except AttributeError:
21
- print(f"Error: Model '{model_name}' not found in src/predict/models.py")
22
- return
23
-
24
- model = ModelClass()
25
-
26
- # 2. Load all available fights for training
27
- if not os.path.exists(FIGHTS_CSV_PATH):
28
- raise FileNotFoundError(f"Fights data not found at '{FIGHTS_CSV_PATH}'.")
29
-
30
- all_fights = pd.read_csv(FIGHTS_CSV_PATH).to_dict('records')
31
- print(f"Training model on all {len(all_fights)} available fights...")
32
-
33
- # 3. Train the model
34
- model.train(all_fights)
35
-
36
- # 4. Save the entire trained model object
37
- model_name_to_save=f"{model_name}.joblib"
38
- save_path = os.path.join(OUTPUT_DIR, model_name_to_save)
39
- joblib.dump(model, save_path)
40
-
41
- print(f"\nModel saved successfully to {save_path}")
42
-
43
- if __name__ == '__main__':
44
- parser = argparse.ArgumentParser(description="Train and save a prediction model.")
45
- parser.add_argument(
46
- '--model',
47
- type=str,
48
- default='XGBoostModel',
49
- help="The name of the model class to train and save."
50
- )
51
- args = parser.parse_args()
52
-
53
- save_model(args.model)