# SwinTExCo-cpu / src / scheduler.py
# duongttr's picture
# Upload app cpu version
# 62b9b3d
from torch.optim.lr_scheduler import _LRScheduler
class PolynomialLR(_LRScheduler):
    """Polynomial learning-rate decay with an optional linear warm-up.

    On every iteration that is a non-zero multiple of ``step_size`` and does
    not exceed ``iter_max``, each parameter group's learning rate is
    recomputed from its base learning rate as::

        lr = (base_lr - min_lr) * coef + min_lr

    where ``t`` is the current iteration (``last_epoch``) and::

        coef = (1 - t / iter_max) ** power                  if t >= iter_warmup
        coef = (t / iter_warmup)
               * (1 - iter_warmup / iter_max) ** power      if t <  iter_warmup

    so the linear warm-up ramp meets the decay curve at ``t == iter_warmup``.
    On every other iteration (iteration 0, iterations between two multiples
    of ``step_size``, and iterations past ``iter_max``) the learning rate
    currently stored in the optimizer is kept unchanged.

    Args:
        optimizer: Wrapped optimizer.
        step_size: Interval, in iterations, between learning-rate updates.
        iter_warmup: Number of warm-up iterations (converted with ``int``).
        iter_max: Iteration at which the decay reaches ``min_lr``
            (converted with ``int``).
        power: Exponent of the polynomial.
        min_lr: Floor the learning rate decays towards.
        last_epoch: Index of the last iteration, forwarded to the base class.

    Raises:
        ValueError: If ``step_size`` is zero, or if ``iter_warmup`` exceeds a
            positive ``iter_max``.
    """

    def __init__(
        self,
        optimizer,
        step_size,
        iter_warmup,
        iter_max,
        power,
        min_lr=0,
        last_epoch=-1,
    ):
        # A zero interval would otherwise only surface as a ZeroDivisionError
        # inside get_lr() on the first step() after construction.
        if step_size == 0:
            raise ValueError("step_size must be non-zero")

        self.step_size = step_size
        self.iter_warmup = int(iter_warmup)
        self.iter_max = int(iter_max)
        self.power = power
        self.min_lr = min_lr

        # With a warm-up longer than the schedule, the warm-up factor
        # (1 - iter_warmup / iter_max) is negative, which produces negative
        # or complex learning rates instead of an error.
        if 0 < self.iter_max < self.iter_warmup:
            raise ValueError(
                "iter_warmup ({}) must not exceed iter_max ({})".format(
                    self.iter_warmup, self.iter_max
                )
            )

        # The base-class constructor performs the initial step, which calls
        # get_lr(); every attribute above must therefore be set beforehand.
        super().__init__(optimizer, last_epoch)

    def polynomial_decay(self, lr):
        """Return the scheduled learning rate for base rate ``lr``."""
        iter_cur = float(self.last_epoch)
        if iter_cur < self.iter_warmup:
            # Linear ramp scaled so that it joins the decay curve exactly at
            # the end of the warm-up.
            coef = iter_cur / self.iter_warmup
            coef *= (1 - self.iter_warmup / self.iter_max) ** self.power
        else:
            coef = (1 - iter_cur / self.iter_max) ** self.power
        return (lr - self.min_lr) * coef + self.min_lr

    def get_lr(self):
        """Return one learning rate per parameter group for this iteration."""
        if (
            (self.last_epoch == 0)
            or (self.last_epoch % self.step_size != 0)
            or (self.last_epoch > self.iter_max)
        ):
            # Not an update iteration: keep whatever the optimizer holds.
            return [group["lr"] for group in self.optimizer.param_groups]
        return [self.polynomial_decay(lr) for lr in self.base_lrs]

    def step_update(self, num_updates):
        """Advance the schedule by one iteration.

        ``num_updates`` is accepted but not used; each call performs exactly
        one ``step()``.
        """
        self.step()