Spaces:
Running
Running
""" | |
Generic helper functions to interact with the ultralytics yolo | |
models. | |
""" | |
from pathlib import Path | |
from ultralytics import YOLO | |
def load_pretrained_model(model_str: str) -> YOLO:
    """Build a `YOLO` model from `model_str`.

    Args:
        model_str (str): identifier handed unchanged to the `YOLO`
            constructor (presumably a model name or a weights-file path —
            confirm against callers).

    Returns:
        YOLO: the constructed model instance.
    """
    pretrained = YOLO(model_str)
    return pretrained
# Baseline training settings. `train` merges caller-supplied overrides on
# top of these before starting a run; key names are Ultralytics train
# settings.
DEFAULT_TRAIN_PARAMS = dict(
    batch=16,
    epochs=100,
    patience=100,
    imgsz=640,
    lr0=0.01,
    lrf=0.01,
    optimizer="auto",
    # data augmentation
    mixup=0.0,
    close_mosaic=10,
    degrees=0.0,
    translate=0.1,
    flipud=0.0,
    fliplr=0.5,
)
def train(
    model: YOLO,
    data_yaml_path: Path,
    params: dict,
    project: Path = Path("data/04_models/yolo/"),
    experiment_name: str = "train",
):
    """Main function for running a train run. It saves the results
    under `project / experiment_name`.

    The caller's `params` are merged on top of `DEFAULT_TRAIN_PARAMS`, and the
    whole merged mapping is forwarded to `model.train`. Every default
    (including `batch`, `patience` and `fliplr`) and every override therefore
    actually reaches the trainer.

    Args:
        model (YOLO): result of `load_pretrained_model`.
        data_yaml_path (Path): filepath to the data.yaml file that specifies
            the split and classes to train on. A `str` path is also accepted.
        params (dict): parameters to override when running the training. See
            https://docs.ultralytics.com/modes/train/#train-settings for a
            complete list of parameters. All keys are passed through to
            `model.train` (which presumably validates them). `None` means
            "no overrides". The `project`, `name` and `data` keys are
            ignored; use the dedicated arguments instead.
        project (Path): root path to store the run artifacts and results.
        experiment_name (str): name of the experiment, that is added to the
            project root path to store the run.

    Returns:
        Whatever `model.train` returns. This is presumably the training
        results/metrics object; verify against the installed ultralytics
        version.

    Raises:
        FileNotFoundError: if `data_yaml_path` does not exist.
    """
    data_yaml_path = Path(data_yaml_path)
    # Explicit raise rather than `assert`: asserts are removed under `python -O`.
    if not data_yaml_path.exists():
        raise FileNotFoundError(f"data_yaml_path does not exist, {data_yaml_path}")

    train_kwargs = {
        **DEFAULT_TRAIN_PARAMS,
        **(params or {}),
        # The dedicated arguments always win. This also avoids a
        # "got multiple values for keyword argument" TypeError if `params`
        # happens to contain one of these keys.
        "project": str(project),
        "name": experiment_name,
        "data": str(data_yaml_path.absolute()),
    }
    return model.train(**train_kwargs)