# CJK-Text-Detection/utils/schedulers.py
from torch.optim.lr_scheduler import _LRScheduler  # newer PyTorch also exposes this as LRScheduler


class ConstantLR(_LRScheduler):
    """Keeps the learning rate fixed at each parameter group's base value."""

    def __init__(self, optimizer, last_epoch=-1):
        super(ConstantLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        # Return the base learning rates unchanged for every parameter group.
        return [base_lr for base_lr in self.base_lrs]


class PolynomialLR(_LRScheduler):
    """Decays each base learning rate by (1 - iter / max_iter) ** power."""

    def __init__(self, optimizer, max_iter, power=0.9, last_epoch=-1):
        self.max_iter = max_iter
        self.power = power
        super(PolynomialLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        # Polynomial decay factor; assumes last_epoch does not exceed max_iter,
        # otherwise the base of the power becomes negative.
        factor = (1 - self.last_epoch / float(self.max_iter)) ** self.power
        return [base_lr * factor for base_lr in self.base_lrs]
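

# Worked example for the decay above (illustrative numbers, not taken from this
# repo): with base_lr = 0.01, max_iter = 100, power = 0.9, the factor at step 50
# is (1 - 50 / 100) ** 0.9 ≈ 0.536, so the learning rate would be ≈ 0.00536.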


class WarmUpLR(_LRScheduler):
    """Wraps another scheduler and scales its learning rates during a warm-up phase.

    For the first `warmup_iters` steps the wrapped scheduler's learning rates are
    multiplied by a factor that either ramps linearly from `gamma` to 1 ("linear")
    or stays at `gamma` ("constant"); afterwards the wrapped rates are returned
    unchanged.
    """

    def __init__(
        self, optimizer, scheduler, mode="linear", warmup_iters=100, gamma=0.2, last_epoch=-1
    ):
        self.mode = mode
        self.scheduler = scheduler
        self.warmup_iters = warmup_iters
        self.gamma = gamma
        super(WarmUpLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        # Learning rates the wrapped scheduler would use at its current epoch.
        cold_lrs = self.scheduler.get_lr()

        if self.last_epoch < self.warmup_iters:
            if self.mode == "linear":
                # Ramp the scale factor from gamma (at step 0) towards 1 (at warmup_iters).
                alpha = self.last_epoch / float(self.warmup_iters)
                factor = self.gamma * (1 - alpha) + alpha
            elif self.mode == "constant":
                factor = self.gamma
            else:
                raise KeyError("WarmUp type {} not implemented".format(self.mode))

            return [factor * base_lr for base_lr in cold_lrs]

        return cold_lrs
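

# The helper below is a usage sketch that is not part of the original file: it
# shows how WarmUpLR is meant to wrap another scheduler such as PolynomialLR.
# Because WarmUpLR.get_lr() reads the wrapped scheduler's get_lr(), the wrapped
# scheduler must also be stepped each iteration if its own schedule is to keep
# advancing after warm-up. The function name and default values here are
# illustrative assumptions, not an API defined elsewhere in this repository.
def build_warmup_poly_scheduler(optimizer, max_iter, warmup_iters=1000, gamma=0.1):
    poly = PolynomialLR(optimizer, max_iter=max_iter)
    warmup = WarmUpLR(optimizer, poly, mode="linear", warmup_iters=warmup_iters, gamma=gamma)
    # In the training loop, step the wrapped scheduler first, then the wrapper:
    #     warmup.scheduler.step(); warmup.step()
    return warmup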


if __name__ == '__main__':
    # Quick visual check: plot the PolynomialLR curve over a full training run.
    import torch
    from torchvision.models import resnet18
    from matplotlib import pyplot as plt

    max_iter = 600 * 125
    model = resnet18()
    op = torch.optim.SGD(model.parameters(), 0.001)
    sc = PolynomialLR(op, max_iter)

    lr = []
    for i in range(max_iter):
        # Only the scheduler is stepped here; no optimizer.step() is needed to
        # inspect the schedule (newer PyTorch warns about this ordering and
        # prefers get_last_lr() over calling get_lr() directly).
        sc.step()
        print(i, sc.last_epoch, sc.get_lr()[0])
        lr.append(sc.get_lr()[0])

    plt.plot(list(range(max_iter)), lr)
    plt.show()
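
    # Additional sketch (not part of the original demo): plot the same polynomial
    # schedule wrapped in WarmUpLR. warmup_iters=1000 and gamma=0.1 are
    # illustrative choices. Both schedulers are stepped each iteration because
    # WarmUpLR.get_lr() reads the wrapped scheduler's get_lr().
    op2 = torch.optim.SGD(model.parameters(), 0.001)
    poly = PolynomialLR(op2, max_iter)
    warm = WarmUpLR(op2, poly, mode="linear", warmup_iters=1000, gamma=0.1)

    warm_lr = []
    for _ in range(max_iter):
        poly.step()
        warm.step()
        warm_lr.append(warm.get_lr()[0])

    plt.plot(list(range(max_iter)), warm_lr)
    plt.show()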