Spaces:
Sleeping
Sleeping
File size: 2,540 Bytes
baa5edc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
import mlflow
import yaml
import os
class MLflowTracker:
    """Reusable MLflow tracking helper configured from a YAML file.

    Constructing an instance loads the config file, optionally sets the
    MLflow tracking URI, selects the experiment, and immediately starts a
    run. The remaining methods forward to the corresponding ``mlflow``
    module-level logging functions.

    Recognised top-level config keys (all optional):
        experiment_name: experiment to select; defaults to
            ``"Default_Experiment"``.
        run_name: name given to the started run; defaults to
            ``"Default_Run"``.
        tracking_uri: passed to ``mlflow.set_tracking_uri`` when truthy.
        metrics: stored on ``metrics_to_track``; this class does not read
            it anywhere else.

    The instance can be used as a context manager so the run is always
    ended, even when the body raises::

        with MLflowTracker("mlflow_config.yaml") as tracker:
            tracker.log_metric("accuracy", 0.95)
    """

    def __init__(self, config_file="mlflow_config.yaml"):
        """Load the configuration and start an MLflow run.

        Args:
            config_file: Path to the YAML configuration file.

        Raises:
            FileNotFoundError: If ``config_file`` does not exist.
            ValueError: If the YAML document is not a mapping at the top
                level.
        """
        if not os.path.exists(config_file):
            raise FileNotFoundError(f"Config file '{config_file}' not found.")

        # Read as UTF-8 explicitly so parsing does not depend on the locale.
        with open(config_file, "r", encoding="utf-8") as f:
            loaded = yaml.safe_load(f)

        # safe_load returns None for an empty (or comment-only) document;
        # treat that as "no settings" so the defaults below apply.
        if loaded is None:
            loaded = {}
        if not isinstance(loaded, dict):
            raise ValueError(
                f"Config file '{config_file}' must contain a YAML mapping "
                f"at the top level, got {type(loaded).__name__}."
            )
        self.config = loaded

        # Pull individual settings, falling back to defaults for absent keys.
        self.experiment_name = self.config.get("experiment_name", "Default_Experiment")
        self.run_name = self.config.get("run_name", "Default_Run")
        self.tracking_uri = self.config.get("tracking_uri", None)
        self.metrics_to_track = self.config.get("metrics", [])

        # The tracking URI has to be set before the experiment is selected.
        if self.tracking_uri:
            mlflow.set_tracking_uri(self.tracking_uri)

        mlflow.set_experiment(self.experiment_name)

        self.run = mlflow.start_run(run_name=self.run_name)
        print(f"MLflow run started: Experiment='{self.experiment_name}', Run='{self.run_name}'")

    def __enter__(self):
        """Return the tracker itself; the run was already started in ``__init__``."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """End the run, marking it ``FAILED`` if the ``with`` body raised."""
        self.end_run(status="FINISHED" if exc_type is None else "FAILED")
        # Never suppress the exception raised inside the ``with`` body.
        return False

    def log_param(self, key, value):
        """Log a single parameter."""
        mlflow.log_param(key, value)

    def log_params(self, params: dict):
        """Log multiple parameters from a dictionary."""
        mlflow.log_params(params)

    def log_metric(self, key, value, step=None):
        """Log a single metric. Optionally include a step value."""
        mlflow.log_metric(key, value, step=step)

    def log_metrics(self, metrics: dict, step=None):
        """Log multiple metrics from a dictionary, all at the same step."""
        for key, value in metrics.items():
            self.log_metric(key, value, step=step)

    def log_artifact(self, file_path, artifact_path=None):
        """Log an artifact (file) to MLflow."""
        mlflow.log_artifact(file_path, artifact_path=artifact_path)

    def end_run(self, status="FINISHED"):
        """End the current MLflow run.

        Args:
            status: Final run status forwarded to ``mlflow.end_run``.
                Defaults to ``"FINISHED"``.
        """
        mlflow.end_run(status=status)
        print("MLflow run ended.")
# Example usage (can be removed or placed in a separate test script):
if __name__ == "__main__":
    tracker = MLflowTracker("mlflow_config.yaml")
    try:
        tracker.log_param("example_param", 123)
        tracker.log_metric("example_metric", 0.95)
    finally:
        # Always close the run, even if one of the logging calls fails,
        # so an active run is not left dangling.
        tracker.end_run()
|