# billion_row_challenge / mlflow_tracker.py
# AnnsKhan's picture
# comi
# baa5edc
import mlflow
import yaml
import os
class MLflowTracker:
    """
    A reusable MLflow tracking class that reads configuration from a YAML file.

    Constructing an instance loads the configuration, optionally sets the
    tracking URI, selects the experiment and starts a run. The instance then
    exposes thin wrappers to log parameters, metrics, and artifacts, and to
    end the run.

    Recognised configuration keys (all optional):
        experiment_name: experiment to select (default "Default_Experiment").
        run_name: name given to the started run (default "Default_Run").
        tracking_uri: passed to ``mlflow.set_tracking_uri`` when truthy.
        metrics: stored as ``metrics_to_track`` (default ``[]``).

    The tracker can also be used as a context manager so that the run is
    ended even when the body raises::

        with MLflowTracker("mlflow_config.yaml") as tracker:
            tracker.log_metric("accuracy", 0.95)
    """

    def __init__(self, config_file="mlflow_config.yaml"):
        """
        Load the YAML configuration and start an MLflow run.

        Args:
            config_file: Path to the YAML configuration file.

        Raises:
            FileNotFoundError: If ``config_file`` does not exist.
            ValueError: If the YAML document is not a mapping at top level.
        """
        # EAFP: open directly instead of checking existence first, which
        # avoids a window between the check and the open.
        try:
            with open(config_file, "r", encoding="utf-8") as f:
                loaded = yaml.safe_load(f)
        except FileNotFoundError:
            raise FileNotFoundError(
                f"Config file '{config_file}' not found."
            ) from None

        # An empty YAML document loads as None; treat it as "no settings"
        # so every key falls back to its default instead of crashing.
        if loaded is None:
            loaded = {}
        if not isinstance(loaded, dict):
            raise ValueError(
                f"Config file '{config_file}' must contain a YAML mapping, "
                f"got {type(loaded).__name__}."
            )
        self.config = loaded

        # Set up configuration parameters
        self.experiment_name = self.config.get("experiment_name", "Default_Experiment")
        self.run_name = self.config.get("run_name", "Default_Run")
        self.tracking_uri = self.config.get("tracking_uri", None)
        self.metrics_to_track = self.config.get("metrics", [])

        # Set tracking URI if provided
        if self.tracking_uri:
            mlflow.set_tracking_uri(self.tracking_uri)

        # Set the experiment
        mlflow.set_experiment(self.experiment_name)

        # Start the run
        self.run = mlflow.start_run(run_name=self.run_name)
        print(f"MLflow run started: Experiment='{self.experiment_name}', Run='{self.run_name}'")

    def __enter__(self):
        """Return the tracker itself; the run was already started in __init__."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """End the run, marking it "FAILED" if the with-body raised."""
        self.end_run(status="FINISHED" if exc_type is None else "FAILED")
        # Never suppress the caller's exception.
        return False

    def log_param(self, key, value):
        """Log a single parameter."""
        mlflow.log_param(key, value)

    def log_params(self, params: dict):
        """Log multiple parameters from a dictionary."""
        mlflow.log_params(params)

    def log_metric(self, key, value, step=None):
        """Log a single metric. Optionally include a step value."""
        mlflow.log_metric(key, value, step=step)

    def log_metrics(self, metrics: dict, step=None):
        """
        Log multiple metrics from a dictionary.

        Args:
            metrics: Mapping of metric name to value.
            step: Optional step value applied to every metric in ``metrics``.
        """
        if not metrics:
            # Nothing to log; skip the call entirely.
            return
        # One batched call instead of one mlflow.log_metric call per key.
        mlflow.log_metrics(metrics, step=step)

    def log_artifact(self, file_path, artifact_path=None):
        """Log an artifact (file) to MLflow."""
        mlflow.log_artifact(file_path, artifact_path=artifact_path)

    def end_run(self, status="FINISHED"):
        """
        End the current MLflow run.

        Args:
            status: Final run status forwarded to ``mlflow.end_run``.
                Defaults to "FINISHED"; ``__exit__`` passes "FAILED" when
                the with-body raised.
        """
        mlflow.end_run(status=status)
        print("MLflow run ended.")
# Example usage (can be removed or placed in a separate test script):
# starts a run from mlflow_config.yaml, logs one parameter and one metric,
# then ends the run.
if __name__ == "__main__":
    tracker = MLflowTracker("mlflow_config.yaml")
    try:
        tracker.log_param("example_param", 123)
        tracker.log_metric("example_metric", 0.95)
    finally:
        # Always end the run, even if one of the logging calls raises.
        tracker.end_run()