import mlflow
import yaml
import os
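

# Hypothetical helper (not part of MLflow or the original class): __init__ below
# stores the config's `metrics` list in `metrics_to_track`, but no method consults
# it. This is a minimal sketch of one way it could be applied, assuming an empty
# list means "log everything" so the default config keeps its current behavior.
def filter_tracked_metrics(metrics, tracked):
    """Return only the entries of `metrics` whose keys appear in `tracked`.

    If `tracked` is empty, a copy of all metrics is returned unchanged
    (an assumption chosen to match the empty-list default in __init__).
    """
    if not tracked:
        return dict(metrics)
    allowed = set(tracked)
    return {key: value for key, value in metrics.items() if key in allowed}

# Usage sketch (hypothetical): inside MLflowTracker.log_metrics one could write
#   metrics = filter_tracked_metrics(metrics, self.metrics_to_track)
# before logging, so only the metrics named in the YAML file are recorded.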

class MLflowTracker:
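    """
    A reusable MLflow tracking class that reads its configuration from a YAML file.

    On construction it sets the tracking URI (if one is given), selects the
    experiment, and starts a run; the methods below then log parameters, metrics,
    and artifacts to that run. Recognized config keys: ``experiment_name``,
    ``run_name``, ``tracking_uri``, and ``metrics`` (a list of metric names,
    stored as ``metrics_to_track`` but not otherwise used by this class).
    """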
    def __init__(self, config_file="mlflow_config.yaml"):
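        # Load configuration from the YAML file.
        if not os.path.exists(config_file):
            raise FileNotFoundError(f"Config file '{config_file}' not found.")
        with open(config_file, "r", encoding="utf-8") as f:
            # yaml.safe_load returns None for an empty file; fall back to an
            # empty dict so the .get() calls below still work.
            self.config = yaml.safe_load(f) or {}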
        
        # Set up configuration parameters
        self.experiment_name = self.config.get("experiment_name", "Default_Experiment")
        self.run_name = self.config.get("run_name", "Default_Run")
        self.tracking_uri = self.config.get("tracking_uri", None)
        self.metrics_to_track = self.config.get("metrics", [])
        
        # Set tracking URI if provided
        if self.tracking_uri:
            mlflow.set_tracking_uri(self.tracking_uri)
        
        # Set the experiment
        mlflow.set_experiment(self.experiment_name)
        
        # Start the run
        self.run = mlflow.start_run(run_name=self.run_name)
        print(f"MLflow run started: Experiment='{self.experiment_name}', Run='{self.run_name}'")
    
    def log_param(self, key, value):
        """Log a single parameter."""
        mlflow.log_param(key, value)
    
    def log_params(self, params: dict):
        """Log multiple parameters from a dictionary."""
        mlflow.log_params(params)
    
    def log_metric(self, key, value, step=None):
        """Log a single metric. Optionally include a step value."""
        mlflow.log_metric(key, value, step=step)
    
    def log_metrics(self, metrics: dict, step=None):
        """Log multiple metrics from a dictionary in a single call."""
        mlflow.log_metrics(metrics, step=step)
    
    def log_artifact(self, file_path, artifact_path=None):
        """Log an artifact (file) to MLflow."""
        mlflow.log_artifact(file_path, artifact_path=artifact_path)
    
    def end_run(self):
        """End the current MLflow run."""
        mlflow.end_run()
        print("MLflow run ended.")
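

from contextlib import contextmanager

# Hypothetical helper (not an MLflow API): callers of MLflowTracker must remember
# to call end_run(), and an exception in between would leave the run open.
# A minimal sketch, assuming only that the tracker has an end_run() method, that
# guarantees the call with try/finally. Note that MLflowTracker.end_run() uses
# mlflow.end_run()'s default status, so a run that raised is still marked
# FINISHED rather than FAILED.
@contextmanager
def tracked_run(tracker):
    """Yield `tracker` and call tracker.end_run() on exit, even if the body raises."""
    try:
        yield tracker
    finally:
        tracker.end_run()

# Usage sketch:
#   with tracked_run(MLflowTracker("mlflow_config.yaml")) as tracker:
#       tracker.log_metric("example_metric", 0.95)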

# Example usage (can be removed or placed in a separate test script):
if __name__ == "__main__":
    tracker = MLflowTracker("mlflow_config.yaml")
    try:
        tracker.log_param("example_param", 123)
        tracker.log_metric("example_metric", 0.95)
    finally:
        # End the run even if logging raises, so it is not left open.
        tracker.end_run()