"""
Generic helper functions for working with Ultralytics YOLO models.
"""

from pathlib import Path

from ultralytics import YOLO


def load_pretrained_model(model_str: str) -> YOLO:
    """Load a pretrained YOLO model from a checkpoint name or path (e.g. `yolov8n.pt`)."""
    return YOLO(model_str)


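# Training settings used unless overridden via `train(params=...)`.
# See https://docs.ultralytics.com/modes/train/#train-settings for their meaning.
DEFAULT_TRAIN_PARAMS = {
    "batch": 16,
    "epochs": 100,
    "patience": 100,
    "imgsz": 640,
    "lr0": 0.01,
    "lrf": 0.01,
    # With "auto", Ultralytics selects the optimizer and its learning rate
    # itself and ignores `lr0`.
    "optimizer": "auto",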
    # data augmentation
    "mixup": 0.0,
    "close_mosaic": 10,
    "degrees": 0.0,
    "translate": 0.1,
    "flipud": 0.0,
    "fliplr": 0.5,
}
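

# A minimal sketch of how `lr0` and `lrf` interact, assuming the linear
# schedule Ultralytics applies when `cos_lr=False` (an assumption about library
# internals, not an API of this module). The learning rate decays from `lr0`
# at epoch 0 to `lr0 * lrf` at the final epoch. Warmup epochs and the automatic
# choice made by `optimizer="auto"` are ignored here.
# `linear_lr` is a hypothetical illustration helper.
def linear_lr(epoch: int, lr0: float, lrf: float, epochs: int) -> float:
    """Approximate learning rate at `epoch` under a linear lr0 -> lr0 * lrf decay."""
    # Fraction of lr0 still applied: 1.0 at epoch 0, lrf at epoch == epochs,
    # and clamped at lrf for any epoch past the end.
    factor = max(1 - epoch / epochs, 0) * (1.0 - lrf) + lrf
    return lr0 * factor


# With the defaults above (lr0=0.01, lrf=0.01, epochs=100) the rate falls from
# 0.01 at epoch 0 to about 0.00505 at epoch 50 and 0.0001 at epoch 100.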


def train(
    model: YOLO,
    data_yaml_path: Path,
    params: dict,
    project: Path = Path("data/04_models/yolo/"),
    experiment_name: str = "train",
):
    """Run a training job and save its results under `project / experiment_name`.

    Args:
        model (YOLO): model returned by `load_pretrained_model`.
        data_yaml_path (Path): path to the data.yaml file that specifies the dataset splits and classes to train on.
        params (dict): training settings that override `DEFAULT_TRAIN_PARAMS`. See https://docs.ultralytics.com/modes/train/#train-settings for the complete list of settings. Must not contain `project`, `name` or `data`, which are set from the other arguments.
        project (Path): root directory for the run artifacts and results.
        experiment_name (str): name of the experiment; it is appended to `project` to form the run directory.

    Returns:
        The value returned by `model.train` (the final validation metrics in recent Ultralytics versions).
    """
    data_yaml_path = Path(data_yaml_path)
    if not data_yaml_path.exists():
        # Raise instead of using `assert`, which is skipped under `python -O`.
        raise FileNotFoundError(f"data_yaml_path does not exist: {data_yaml_path}")
    params = {**DEFAULT_TRAIN_PARAMS, **params}
    # Forward every merged setting (including batch, patience and all data
    # augmentation parameters) so that each key in `params` takes effect.
    return model.train(
        project=str(project),
        name=experiment_name,
        data=str(data_yaml_path.absolute()),
        **params,
    )
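

# The data.yaml passed to `train` tells Ultralytics where the dataset lives and
# which classes it contains. A minimal sketch of a detection data.yaml, assuming
# the common layout with `path`, `train`, `val` and an index -> name `names`
# mapping (see the Ultralytics dataset docs for the full format). The helper
# `format_data_yaml` and its default split folders are hypothetical; write the
# result with `Path("data.yaml").write_text(...)` before calling `train`.
def format_data_yaml(
    dataset_root: str,
    class_names: list[str],
    train_dir: str = "images/train",
    val_dir: str = "images/val",
) -> str:
    """Return the text of a minimal Ultralytics-style data.yaml file."""
    if not class_names:
        raise ValueError("class_names must contain at least one class")

    def quote(value: str) -> str:
        # YAML double-quoted scalar: escape backslashes first, then quotes.
        escaped = value.replace("\\", "\\\\").replace('"', '\\"')
        return f'"{escaped}"'

    lines = [
        f"path: {quote(dataset_root)}",
        f"train: {quote(train_dir)}",  # relative to `path`
        f"val: {quote(val_dir)}",  # relative to `path`
        "names:",
    ]
    # Class indices must match the integer labels used in the annotation files.
    lines += [f"  {idx}: {quote(name)}" for idx, name in enumerate(class_names)]
    return "\n".join(lines) + "\n"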