Spaces:
Sleeping
Sleeping
import mlflow | |
import yaml | |
import os | |
class MLflowTracker:
    """
    A reusable MLflow tracking class that reads configuration from a YAML file.

    Construction loads the YAML config, optionally sets the tracking URI,
    sets the active experiment and starts a run. The logging methods below go
    through MLflow's fluent API, i.e. they log to the currently active run.

    The tracker can also be used as a context manager so the run is always
    ended, and marked ``FAILED`` if the ``with`` body raises::

        with MLflowTracker("mlflow_config.yaml") as tracker:
            tracker.log_metric("accuracy", 0.95)

    Config keys read (all optional): ``experiment_name``, ``run_name``,
    ``tracking_uri`` and ``metrics``.
    """

    def __init__(self, config_file="mlflow_config.yaml"):
        """
        Load the YAML configuration and start an MLflow run.

        Args:
            config_file: Path to the YAML configuration file.

        Raises:
            FileNotFoundError: If ``config_file`` does not exist.
            TypeError: If the YAML document is not a mapping at the top level.
        """
        # EAFP: open directly rather than checking existence first (avoids a
        # check-then-open race) while keeping the original error message.
        try:
            with open(config_file, "r", encoding="utf-8") as f:
                raw_config = yaml.safe_load(f)
        except FileNotFoundError as err:
            raise FileNotFoundError(f"Config file '{config_file}' not found.") from err
        self.config = self._normalize_config(raw_config)

        # Set up configuration parameters; defaults apply to missing keys.
        self.experiment_name = self.config.get("experiment_name", "Default_Experiment")
        self.run_name = self.config.get("run_name", "Default_Run")
        self.tracking_uri = self.config.get("tracking_uri", None)
        # Stored for callers; this class itself does not filter on it.
        self.metrics_to_track = self.config.get("metrics", [])

        # Only override MLflow's tracking URI when one is configured.
        if self.tracking_uri:
            mlflow.set_tracking_uri(self.tracking_uri)
        mlflow.set_experiment(self.experiment_name)
        self.run = mlflow.start_run(run_name=self.run_name)
        print(f"MLflow run started: Experiment='{self.experiment_name}', Run='{self.run_name}'")

    @staticmethod
    def _normalize_config(raw_config):
        """
        Validate the parsed YAML document and return it as a dict.

        Args:
            raw_config: The object returned by ``yaml.safe_load``.

        Returns:
            ``raw_config`` itself if it is a dict, or ``{}`` if it is ``None``.

        Raises:
            TypeError: If ``raw_config`` is neither ``None`` nor a dict.
        """
        if raw_config is None:
            # yaml.safe_load returns None for an empty document.
            return {}
        if not isinstance(raw_config, dict):
            raise TypeError(
                "MLflow config must be a mapping at the top level, "
                f"got {type(raw_config).__name__}."
            )
        return raw_config

    def __enter__(self):
        """Return the tracker; the run was already started in ``__init__``."""
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        """End the run, marking it ``FAILED`` if the ``with`` body raised."""
        self.end_run(status="FAILED" if exc_type is not None else "FINISHED")
        # Do not suppress the exception, if any.
        return False

    def log_param(self, key, value):
        """Log a single parameter."""
        mlflow.log_param(key, value)

    def log_params(self, params: dict):
        """Log multiple parameters from a dictionary in one batch call."""
        mlflow.log_params(params)

    def log_metric(self, key, value, step=None):
        """Log a single metric. Optionally include a step value."""
        mlflow.log_metric(key, value, step=step)

    def log_metrics(self, metrics: dict, step=None):
        """
        Log multiple metrics from a dictionary in one batch call.

        Args:
            metrics: Mapping of metric name to value.
            step: Optional step applied to every metric in ``metrics``.
        """
        # Nothing to send for an empty mapping; skip the request entirely.
        if not metrics:
            return
        mlflow.log_metrics(metrics, step=step)

    def log_artifact(self, file_path, artifact_path=None):
        """Log an artifact (file) to MLflow, optionally under ``artifact_path``."""
        mlflow.log_artifact(file_path, artifact_path=artifact_path)

    def end_run(self, status="FINISHED"):
        """
        End the current MLflow run.

        Args:
            status: Final run status passed to ``mlflow.end_run``
                (default ``"FINISHED"``).
        """
        mlflow.end_run(status=status)
        print("MLflow run ended.")
# Example usage (can be removed or placed in a separate test script).
if __name__ == "__main__":
    tracker = MLflowTracker("mlflow_config.yaml")
    try:
        tracker.log_param("example_param", 123)
        tracker.log_metric("example_metric", 0.95)
    finally:
        # Always close the run explicitly, even if a logging call raised.
        tracker.end_run()