training_bench / algorithms /base_optimizer.py
rider-provider-777's picture
Upload 5 files
e016a55 verified
raw
history blame
520 Bytes
from abc import ABC, abstractmethod
from typing import Any
class BaseOptimizer(ABC):
    """Abstract base class for all training algorithms.

    Subclasses must implement :meth:`step`. The constructor only stores its
    arguments; no validation or setup is performed here.

    Attributes:
        model: The model passed to the constructor. Its type is not
            constrained here (presumably a framework model object -- confirm
            against callers).
        config: The configuration passed to the constructor. Its schema is
            not defined in this class.
        accelerator: ``None`` until :meth:`set_accelerator` is called, after
            which it holds whatever object was passed in.
    """

    def __init__(self, model: Any, config: Any) -> None:
        """Store the model and config, starting with no accelerator attached.

        Args:
            model: The model this algorithm will train.
            config: Algorithm configuration, kept as-is on ``self.config``.
        """
        self.model = model
        self.config = config
        # Remains None until set_accelerator() is called.
        self.accelerator: Any = None

    def set_accelerator(self, accelerator: Any) -> None:
        """Attach an accelerator object for subclasses to use.

        The object is stored unchanged on ``self.accelerator``. No checks are
        performed, and any previously set accelerator is replaced.

        Args:
            accelerator: The accelerator object to store (type not constrained
                here; NOTE(review): likely an ``accelerate.Accelerator`` --
                confirm with callers).
        """
        self.accelerator = accelerator

    @abstractmethod
    def step(self, inputs: Any, labels: Any) -> float:
        """Perform a single training step.

        Args:
            inputs: One batch of model inputs (format defined by the subclass).
            labels: The targets matching ``inputs`` (format defined by the
                subclass).

        Returns:
            The loss for this step as a plain Python ``float``.

        Raises:
            NotImplementedError: If this base implementation is invoked, e.g.
                through ``super().step(...)`` from a subclass.
        """
        raise NotImplementedError