from abc import ABC, abstractmethod


class BaseOptimizer(ABC):
    """Common interface shared by every training algorithm.

    Concrete subclasses keep a reference to the model and its configuration.
    They may be handed an accelerator after construction and must implement
    ``step``.
    """

    def __init__(self, model, config):
        """Store the model and configuration; start with no accelerator.

        Args:
            model: The model this optimizer works on (type not constrained here).
            config: Optimizer configuration (type not constrained here).
        """
        # Stays None until set_accelerator() is called.
        self.accelerator = None
        self.model, self.config = model, config

    def set_accelerator(self, accelerator) -> None:
        """Attach *accelerator* to this optimizer, replacing any previous one."""
        self.accelerator = accelerator

    @abstractmethod
    def step(self, inputs, labels) -> float:
        """Run one training step on the given inputs and labels.

        Subclasses must override this and return the loss as a Python float.

        Raises:
            NotImplementedError: If the base implementation is invoked directly.
        """
        raise NotImplementedError()