Spaces:
Running
Running
Commit
·
371767b
1
Parent(s):
eb615ca
Fix import
Browse files- .gitignore +1 -0
- src/predict/main.py +4 -4
- src/predict/models.py +2 -2
- src/predict/pipeline.py +6 -16
.gitignore
CHANGED
|
@@ -1,3 +1,4 @@
|
|
| 1 |
*__pycache__/
|
| 2 |
example_event.html
|
| 3 |
web/
|
|
|
|
|
|
| 1 |
*__pycache__/
|
| 2 |
example_event.html
|
| 3 |
web/
|
| 4 |
+
mlruns/
|
src/predict/main.py
CHANGED
|
@@ -16,9 +16,9 @@ from .models import (
|
|
| 16 |
MODELS_TO_RUN = [
|
| 17 |
EloBaselineModel(),
|
| 18 |
LogisticRegressionModel(),
|
| 19 |
-
XGBoostModel(),
|
| 20 |
-
SVCModel(),
|
| 21 |
-
RandomForestModel(),
|
| 22 |
BernoulliNBModel(),
|
| 23 |
LGBMModel(),
|
| 24 |
]
|
|
@@ -58,7 +58,7 @@ def main():
|
|
| 58 |
parser.add_argument(
|
| 59 |
'--kfold',
|
| 60 |
action='store_true',
|
| 61 |
-
help='Run
|
| 62 |
)
|
| 63 |
args = parser.parse_args()
|
| 64 |
|
|
|
|
| 16 |
MODELS_TO_RUN = [
|
| 17 |
EloBaselineModel(),
|
| 18 |
LogisticRegressionModel(),
|
| 19 |
+
# XGBoostModel(),
|
| 20 |
+
# SVCModel(),
|
| 21 |
+
# RandomForestModel(),
|
| 22 |
BernoulliNBModel(),
|
| 23 |
LGBMModel(),
|
| 24 |
]
|
|
|
|
| 58 |
parser.add_argument(
|
| 59 |
'--kfold',
|
| 60 |
action='store_true',
|
| 61 |
+
help='Run 10-fold CV instead of standard split.'
|
| 62 |
)
|
| 63 |
args = parser.parse_args()
|
| 64 |
|
src/predict/models.py
CHANGED
|
@@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
|
|
| 2 |
import sys
|
| 3 |
import os
|
| 4 |
import pandas as pd
|
| 5 |
-
from typing import Dict, Any, Optional
|
| 6 |
from sklearn.linear_model import LogisticRegression
|
| 7 |
from sklearn.svm import SVC
|
| 8 |
from sklearn.naive_bayes import BernoulliNB
|
|
@@ -88,7 +88,7 @@ class BaseMLModel(BaseModel):
|
|
| 88 |
self.fighters_df = None
|
| 89 |
self.fighter_histories = {}
|
| 90 |
|
| 91 |
-
def train(self, train_fights:
|
| 92 |
"""
|
| 93 |
Trains the machine learning model. This involves loading fighter data,
|
| 94 |
pre-calculating histories, and fitting the model on the preprocessed data.
|
|
|
|
| 2 |
import sys
|
| 3 |
import os
|
| 4 |
import pandas as pd
|
| 5 |
+
from typing import Dict, Any, Optional, List
|
| 6 |
from sklearn.linear_model import LogisticRegression
|
| 7 |
from sklearn.svm import SVC
|
| 8 |
from sklearn.naive_bayes import BernoulliNB
|
|
|
|
| 88 |
self.fighters_df = None
|
| 89 |
self.fighter_histories = {}
|
| 90 |
|
| 91 |
+
def train(self, train_fights: List[Dict[str, Any]]) -> None:
|
| 92 |
"""
|
| 93 |
Trains the machine learning model. This involves loading fighter data,
|
| 94 |
pre-calculating histories, and fitting the model on the preprocessed data.
|
src/predict/pipeline.py
CHANGED
|
@@ -291,27 +291,17 @@ class PredictionPipeline:
|
|
| 291 |
# Train and evaluate
|
| 292 |
model.train(train_set)
|
| 293 |
correct = 0
|
| 294 |
-
total_fights = 0
|
| 295 |
for fight in test_set:
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
correct += 1
|
| 300 |
-
total_fights += 1
|
| 301 |
|
| 302 |
-
acc = correct /
|
| 303 |
fold_results[model_name] = acc
|
| 304 |
|
| 305 |
-
# Log metrics and
|
| 306 |
mlflow.log_metric(f"accuracy_{model_name}", acc)
|
| 307 |
-
mlflow.
|
| 308 |
-
|
| 309 |
-
# Register the model with MLflow to appear in Models tab
|
| 310 |
-
mlflow.sklearn.log_model(
|
| 311 |
-
model,
|
| 312 |
-
f"model_{model_name}",
|
| 313 |
-
registered_model_name=f"{model_name}_UFC_Model"
|
| 314 |
-
)
|
| 315 |
|
| 316 |
all_fold_metrics.append(fold_results)
|
| 317 |
|
|
|
|
| 291 |
# Train and evaluate
|
| 292 |
model.train(train_set)
|
| 293 |
correct = 0
|
|
|
|
| 294 |
for fight in test_set:
|
| 295 |
+
prediction = model.predict(fight)
|
| 296 |
+
if prediction.get('winner') == fight['winner']:
|
| 297 |
+
correct += 1
|
|
|
|
|
|
|
| 298 |
|
| 299 |
+
acc = correct / len(test_set) if test_set else 0.0
|
| 300 |
fold_results[model_name] = acc
|
| 301 |
|
| 302 |
+
# Log metrics and model artifact
|
| 303 |
mlflow.log_metric(f"accuracy_{model_name}", acc)
|
| 304 |
+
mlflow.sklearn.log_model(model, f"model_{model_name}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 305 |
|
| 306 |
all_fold_metrics.append(fold_results)
|
| 307 |
|