Spaces:
Build error
Build error
from bisect import bisect_right | |
from torch.optim.lr_scheduler import _LRScheduler | |
class LRStepScheduler(_LRScheduler):
    """Piecewise-constant learning-rate schedule driven by explicit milestones.

    ``steps`` is a sequence of ``(epoch, lr)`` pairs. Once ``last_epoch``
    reaches a milestone's epoch, every parameter group uses that milestone's
    ``lr`` until a later milestone is reached. Before the first milestone
    (and for an empty ``steps``) the optimizer's initial learning rates are
    used.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        steps (Sequence[Tuple[int, float]]): ``(epoch, lr)`` milestones.
            Expected to be sorted by ascending epoch, since the lookup in
            ``get_lr`` is a binary search over the epochs.
        last_epoch (int): The index of last epoch. Default: -1.
    """

    def __init__(self, optimizer, steps, last_epoch=-1):
        # Must be set before super().__init__, which immediately calls
        # step() -> get_lr().
        self.lr_steps = steps
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        milestones = [epoch for epoch, _ in self.lr_steps]
        # Index of the last milestone whose epoch is <= last_epoch. It is -1
        # when no milestone has been reached yet, which also covers an empty
        # schedule (previously that case raised IndexError on lr_steps[0]).
        pos = bisect_right(milestones, self.last_epoch) - 1
        if pos < 0:
            return list(self.base_lrs)
        lr = self.lr_steps[pos][1]
        return [lr for _ in self.base_lrs]
class PolyLR(_LRScheduler):
    """Sets the learning rate of each parameter group according to poly learning rate policy.

    At scheduler step ``t`` each parameter group uses::

        base_lr * (1 - (t % max_iter) / max_iter) ** power

    The modulo restarts the decay every ``max_iter`` steps rather than letting
    the base go negative (a negative float raised to a fractional power is a
    complex number in Python).

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        max_iter (int): Length of one decay cycle, in scheduler steps.
            Default: 90000.
        power (float): Exponent of the polynomial decay. Default: 0.9.
        last_epoch (int): The index of last epoch. Default: -1.
    """

    def __init__(self, optimizer, max_iter=90000, power=0.9, last_epoch=-1):
        self.max_iter = max_iter
        self.power = power
        super(PolyLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        # Keep this a pure function of last_epoch. step() has already advanced
        # last_epoch; incrementing it here too (as before) advanced the
        # schedule by two per step, skipped t=0, and corrupted the value saved
        # by state_dict().
        progress = (self.last_epoch % self.max_iter) / self.max_iter
        factor = (1 - progress) ** self.power
        return [base_lr * factor for base_lr in self.base_lrs]
class ExponentialLRScheduler(_LRScheduler):
    """Multiply every parameter group's learning rate by ``gamma`` once per epoch.

    The rate is computed in closed form as ``base_lr * gamma ** last_epoch``.
    For epoch 0 (or any non-positive epoch) the initial learning rates are
    returned unchanged, so constructing with ``last_epoch=-1`` starts from
    the optimizer's configured lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        gamma (float): Multiplicative factor of learning rate decay per epoch.
        last_epoch (int): The index of last epoch. Default: -1.
    """

    def __init__(self, optimizer, gamma, last_epoch=-1):
        self.gamma = gamma
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        epoch = self.last_epoch
        if epoch > 0:
            decay = self.gamma ** epoch
            return [initial * decay for initial in self.base_lrs]
        # Nothing has decayed yet: hand back the initial rates as-is.
        return self.base_lrs