trout-reID / yolo.py
achouffe's picture
feat: initial commit
641857b verified
"""
Generic helper functions to interact with the ultralytics yolo
models.
"""
from pathlib import Path
from ultralytics import YOLO
def load_pretrained_model(model_str: str) -> YOLO:
    """Instantiate a YOLO model from `model_str`.

    Args:
        model_str (str): value handed unchanged to the `YOLO` constructor
            (presumably a checkpoint path or model name — whatever
            ultralytics accepts there).

    Returns:
        YOLO: the freshly constructed model.
    """
    pretrained = YOLO(model_str)
    return pretrained
# Baseline training settings. `train` merges these with the caller's
# `params` dict, and the caller's values take precedence. Key names are
# used as keyword arguments to ultralytics' `model.train`.
DEFAULT_TRAIN_PARAMS = dict(
    # optimisation / schedule
    batch=16,
    epochs=100,
    patience=100,
    imgsz=640,
    lr0=0.01,
    lrf=0.01,
    optimizer="auto",
    # data augmentation
    mixup=0.0,
    close_mosaic=10,
    degrees=0.0,
    translate=0.1,
    flipud=0.0,
    fliplr=0.5,
)
def train(
    model: YOLO,
    data_yaml_path: Path,
    params: dict,
    project: Path = Path("data/04_models/yolo/"),
    experiment_name: str = "train",
):
    """Main function for running a train run. It saves the results
    under `project / experiment_name`.

    `DEFAULT_TRAIN_PARAMS` is merged with `params` (values in `params` win)
    and every resulting entry is forwarded to `model.train`, so any
    ultralytics train setting can be overridden, not only a fixed subset.

    Args:
        model (YOLO): result of `load_pretrained_model`.
        data_yaml_path (Path): filepath to the data.yaml file that specifies
            the split and classes to train on. A `str` path is also accepted.
        params (dict): parameters to override when running the training.
            See https://docs.ultralytics.com/modes/train/#train-settings for
            a complete list of parameters. `None` means "no overrides".
            The keys `project`, `name` and `data` are ignored because they
            are controlled by the dedicated arguments of this function.
        project (Path): root path to store the run artifacts and results.
        experiment_name (str): name of the experiment, that is added to the
            project root path to store the run.

    Returns:
        The value returned by `model.train` (previously discarded).

    Raises:
        FileNotFoundError: if `data_yaml_path` does not exist.
    """
    data_yaml_path = Path(data_yaml_path)
    # An explicit raise instead of `assert`: asserts are stripped under `python -O`.
    if not data_yaml_path.exists():
        raise FileNotFoundError(f"data_yaml_path does not exist, {data_yaml_path}")

    train_kwargs = {**DEFAULT_TRAIN_PARAMS, **(params or {})}
    # The dedicated arguments always take precedence over `params`. This
    # matches the previous behaviour, which never read these keys from
    # `params`, and avoids a duplicate-keyword TypeError in `model.train`.
    train_kwargs.update(
        project=str(project),
        name=experiment_name,
        data=str(data_yaml_path.absolute()),
    )
    return model.train(**train_kwargs)