Spaces:
Running
Running
Commit
·
1e04613
1
Parent(s):
b7da0df
Integrate model saving into prediction pipeline
Browse filesMoved model training and saving logic from save_model.py into the PredictionPipeline class. Updated config to define MODELS_DIR and refactored predict_new.py to use the new model directory. Removed the now-redundant save_model.py script.
- src/config.py +2 -6
- src/predict/pipeline.py +33 -2
- src/predict/predict_new.py +2 -2
- src/predict/save_model.py +0 -53
src/config.py
CHANGED
@@ -1,11 +1,7 @@
|
|
1 |
import os
|
2 |
|
3 |
OUTPUT_DIR = 'output'
|
|
|
|
|
4 |
FIGHTS_CSV_PATH = os.path.join(OUTPUT_DIR, 'ufc_fights.csv')
|
5 |
FIGHTERS_CSV_PATH = os.path.join(OUTPUT_DIR, 'ufc_fighters.csv')
|
6 |
-
MODEL_RESULTS_PATH = os.path.join(OUTPUT_DIR, 'model_results.json')
|
7 |
-
|
8 |
-
# JSON files (temporary)
|
9 |
-
EVENTS_JSON_PATH = os.path.join(OUTPUT_DIR, 'ufc_fights.json')
|
10 |
-
FIGHTERS_JSON_PATH = os.path.join(OUTPUT_DIR, 'ufc_fighters.json')
|
11 |
-
|
|
|
1 |
import os
|
2 |
|
3 |
OUTPUT_DIR = 'output'
|
4 |
+
MODELS_DIR = os.path.join(OUTPUT_DIR, 'models')
|
5 |
+
MODEL_RESULTS_PATH = os.path.join(OUTPUT_DIR, 'model_results.json')
|
6 |
FIGHTS_CSV_PATH = os.path.join(OUTPUT_DIR, 'ufc_fights.csv')
|
7 |
FIGHTERS_CSV_PATH = os.path.join(OUTPUT_DIR, 'ufc_fighters.csv')
|
|
|
|
|
|
|
|
|
|
|
|
src/predict/pipeline.py
CHANGED
@@ -4,8 +4,9 @@ import sys
|
|
4 |
from datetime import datetime
|
5 |
from collections import OrderedDict
|
6 |
import json
|
|
|
7 |
|
8 |
-
from ..config import FIGHTS_CSV_PATH, MODEL_RESULTS_PATH
|
9 |
from .models import BaseModel
|
10 |
|
11 |
class PredictionPipeline:
|
@@ -43,7 +44,7 @@ class PredictionPipeline:
|
|
43 |
print(f"Testing on the last {num_test_events} events.")
|
44 |
|
45 |
def run(self, detailed_report=True):
|
46 |
-
"""Executes the full pipeline: load, train, evaluate, and
|
47 |
self._load_and_split_data()
|
48 |
|
49 |
eval_fights = [f for f in self.test_fights if f['winner'] not in ["Draw", "NC", ""]]
|
@@ -91,6 +92,36 @@ class PredictionPipeline:
|
|
91 |
else:
|
92 |
self._report_summary()
|
93 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
def _report_summary(self):
|
95 |
"""Prints a concise summary of model performance."""
|
96 |
print("\n\n--- Prediction Pipeline Summary ---")
|
|
|
4 |
from datetime import datetime
|
5 |
from collections import OrderedDict
|
6 |
import json
|
7 |
+
import joblib
|
8 |
|
9 |
+
from ..config import FIGHTS_CSV_PATH, MODEL_RESULTS_PATH, MODELS_DIR
|
10 |
from .models import BaseModel
|
11 |
|
12 |
class PredictionPipeline:
|
|
|
44 |
print(f"Testing on the last {num_test_events} events.")
|
45 |
|
46 |
def run(self, detailed_report=True):
|
47 |
+
"""Executes the full pipeline: load, train, evaluate, report and save models."""
|
48 |
self._load_and_split_data()
|
49 |
|
50 |
eval_fights = [f for f in self.test_fights if f['winner'] not in ["Draw", "NC", ""]]
|
|
|
92 |
else:
|
93 |
self._report_summary()
|
94 |
|
95 |
+
self._train_and_save_models()
|
96 |
+
|
97 |
+
def _train_and_save_models(self):
|
98 |
+
"""Trains all models on the full dataset and saves them."""
|
99 |
+
print("\n\n--- Training and Saving All Models on Full Dataset ---")
|
100 |
+
|
101 |
+
if not os.path.exists(FIGHTS_CSV_PATH):
|
102 |
+
print(f"Error: Fights data not found at '{FIGHTS_CSV_PATH}'. Cannot save models.")
|
103 |
+
return
|
104 |
+
|
105 |
+
with open(FIGHTS_CSV_PATH, 'r', encoding='utf-8') as f:
|
106 |
+
all_fights = list(csv.DictReader(f))
|
107 |
+
|
108 |
+
print(f"Training models on all {len(all_fights)} available fights...")
|
109 |
+
|
110 |
+
if not os.path.exists(MODELS_DIR):
|
111 |
+
os.makedirs(MODELS_DIR)
|
112 |
+
print(f"Created directory: {MODELS_DIR}")
|
113 |
+
|
114 |
+
for model in self.models:
|
115 |
+
model_name = model.__class__.__name__
|
116 |
+
print(f"\n--- Training: {model_name} ---")
|
117 |
+
model.train(all_fights)
|
118 |
+
|
119 |
+
# Sanitize and save the model
|
120 |
+
file_name = f"{model_name}.joblib"
|
121 |
+
save_path = os.path.join(MODELS_DIR, file_name)
|
122 |
+
joblib.dump(model, save_path)
|
123 |
+
print(f"Model saved successfully to {save_path}")
|
124 |
+
|
125 |
def _report_summary(self):
|
126 |
"""Prints a concise summary of model performance."""
|
127 |
print("\n\n--- Prediction Pipeline Summary ---")
|
src/predict/predict_new.py
CHANGED
@@ -3,7 +3,7 @@ import os
|
|
3 |
import joblib
|
4 |
from datetime import datetime
|
5 |
|
6 |
-
from ..config import
|
7 |
|
8 |
def predict_new_fight(fighter1_name, fighter2_name, model_path):
|
9 |
"""
|
@@ -45,7 +45,7 @@ if __name__ == '__main__':
|
|
45 |
parser.add_argument(
|
46 |
'--model_path',
|
47 |
type=str,
|
48 |
-
default=os.path.join(
|
49 |
help="Path to the saved model file."
|
50 |
)
|
51 |
args = parser.parse_args()
|
|
|
3 |
import joblib
|
4 |
from datetime import datetime
|
5 |
|
6 |
+
from ..config import MODELS_DIR
|
7 |
|
8 |
def predict_new_fight(fighter1_name, fighter2_name, model_path):
|
9 |
"""
|
|
|
45 |
parser.add_argument(
|
46 |
'--model_path',
|
47 |
type=str,
|
48 |
+
default=os.path.join(MODELS_DIR, 'XGBoostModel.joblib'),
|
49 |
help="Path to the saved model file."
|
50 |
)
|
51 |
args = parser.parse_args()
|
src/predict/save_model.py
DELETED
@@ -1,53 +0,0 @@
|
|
1 |
-
import argparse
|
2 |
-
import os
|
3 |
-
import joblib
|
4 |
-
import pandas as pd
|
5 |
-
|
6 |
-
from ..config import FIGHTS_CSV_PATH, OUTPUT_DIR
|
7 |
-
import src.predict.models as models
|
8 |
-
|
9 |
-
def save_model(model_name):
|
10 |
-
"""
|
11 |
-
Trains a specified model on the entire dataset and saves it to a file.
|
12 |
-
|
13 |
-
:param model_name: The name of the model class to train (e.g., 'XGBoostModel').
|
14 |
-
"""
|
15 |
-
print(f"--- Training and Saving Model: {model_name} ---")
|
16 |
-
|
17 |
-
# 1. Get the model class from the models module
|
18 |
-
try:
|
19 |
-
ModelClass = getattr(models, model_name)
|
20 |
-
except AttributeError:
|
21 |
-
print(f"Error: Model '{model_name}' not found in src/predict/models.py")
|
22 |
-
return
|
23 |
-
|
24 |
-
model = ModelClass()
|
25 |
-
|
26 |
-
# 2. Load all available fights for training
|
27 |
-
if not os.path.exists(FIGHTS_CSV_PATH):
|
28 |
-
raise FileNotFoundError(f"Fights data not found at '{FIGHTS_CSV_PATH}'.")
|
29 |
-
|
30 |
-
all_fights = pd.read_csv(FIGHTS_CSV_PATH).to_dict('records')
|
31 |
-
print(f"Training model on all {len(all_fights)} available fights...")
|
32 |
-
|
33 |
-
# 3. Train the model
|
34 |
-
model.train(all_fights)
|
35 |
-
|
36 |
-
# 4. Save the entire trained model object
|
37 |
-
model_name_to_save=f"{model_name}.joblib"
|
38 |
-
save_path = os.path.join(OUTPUT_DIR, model_name_to_save)
|
39 |
-
joblib.dump(model, save_path)
|
40 |
-
|
41 |
-
print(f"\nModel saved successfully to {save_path}")
|
42 |
-
|
43 |
-
if __name__ == '__main__':
|
44 |
-
parser = argparse.ArgumentParser(description="Train and save a prediction model.")
|
45 |
-
parser.add_argument(
|
46 |
-
'--model',
|
47 |
-
type=str,
|
48 |
-
default='XGBoostModel',
|
49 |
-
help="The name of the model class to train and save."
|
50 |
-
)
|
51 |
-
args = parser.parse_args()
|
52 |
-
|
53 |
-
save_model(args.model)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|