Spaces:
Running
Running
Evgeny Zhukov
Origin: https://github.com/ali-vilab/UniAnimate/commit/d7814fa44a0a1154524b92fce0e3133a2604d333
2ba4412
import math | |
from torch.optim.lr_scheduler import _LRScheduler | |
# Public API of this module: `from <module> import *` exports only the scheduler class.
__all__ = ['AnnealingLR']
class AnnealingLR(_LRScheduler):
    """Learning-rate schedule: linear warmup, then linear/cosine decay.

    The learning rate rises linearly from 0 to ``base_lr`` over the first
    ``warmup_steps`` steps. It then decays from ``base_lr`` towards 0 over the
    following ``total_steps - warmup_steps`` steps according to ``decay_mode``
    ('linear', 'cosine', or 'none' to hold ``base_lr``). Every value written to
    the optimizer(s) is floored at ``min_lr``.

    ``_LRScheduler.__init__`` is intentionally not called. This class writes a
    single learning rate into every param group itself.

    Args:
        optimizer: an object exposing ``param_groups`` (e.g. a torch optimizer),
            or a ``list`` of such objects; all of their groups get the same lr.
        base_lr: peak learning rate, reached at the end of warmup.
        warmup_steps: number of linear-warmup steps (0 disables warmup).
        total_steps: step at which the decay phase ends.
        decay_mode: 'linear', 'cosine' or 'none'.
        min_lr: lower bound applied to every learning rate written.
        last_step: last completed step. The schedule is applied at
            ``last_step + 1`` during construction.

    Raises:
        ValueError: if ``decay_mode`` is not a supported mode.
    """

    _DECAY_MODES = ('linear', 'cosine', 'none')

    def __init__(self, optimizer, base_lr, warmup_steps, total_steps,
                 decay_mode='cosine', min_lr=0.0, last_step=-1):
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if decay_mode not in self._DECAY_MODES:
            raise ValueError(
                f"decay_mode must be one of {self._DECAY_MODES}, got {decay_mode!r}")
        self.optimizer = optimizer
        self.base_lr = base_lr
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.decay_mode = decay_mode
        self.min_lr = min_lr
        self.current_step = last_step + 1
        # Write the initial learning rate into the optimizer(s) immediately.
        self.step(self.current_step)

    def get_lr(self):
        """Return the schedule value at ``self.current_step`` (before the ``min_lr`` floor)."""
        step = self.current_step
        if self.warmup_steps > 0 and step <= self.warmup_steps:
            return self.base_lr * step / self.warmup_steps
        if self.decay_mode == 'none':
            return self.base_lr
        decay_steps = self.total_steps - self.warmup_steps
        if decay_steps > 0:
            ratio = (step - self.warmup_steps) / decay_steps
        else:
            # No room for a decay phase (total_steps <= warmup_steps): the schedule
            # is fully decayed once warmup is over. This avoids dividing by zero.
            ratio = 1.0 if step > self.warmup_steps else 0.0
        ratio = min(1.0, max(0.0, ratio))
        if self.decay_mode == 'linear':
            return self.base_lr * (1 - ratio)
        # 'cosine': half-period cosine from base_lr (ratio=0) down to 0 (ratio=1).
        return self.base_lr * (math.cos(math.pi * ratio) + 1.0) / 2.0

    def step(self, current_step=None):
        """Move to ``current_step`` and write the new lr into every param group.

        Args:
            current_step: step to move to; defaults to ``self.current_step + 1``.
        """
        if current_step is None:
            current_step = self.current_step + 1
        self.current_step = current_step
        new_lr = max(self.min_lr, self.get_lr())
        optimizers = self.optimizer if isinstance(self.optimizer, list) else [self.optimizer]
        for optimizer in optimizers:
            for group in optimizer.param_groups:
                group['lr'] = new_lr
        # The inherited get_last_lr() reads `_last_lr`. Set it here because the
        # base-class __init__ (which would normally set it) is never run.
        self._last_lr = [new_lr for optimizer in optimizers for _ in optimizer.param_groups]

    def state_dict(self):
        """Return the scheduler configuration and position as a plain dict."""
        return {
            'base_lr': self.base_lr,
            'warmup_steps': self.warmup_steps,
            'total_steps': self.total_steps,
            'decay_mode': self.decay_mode,
            'min_lr': self.min_lr,
            'current_step': self.current_step}

    def load_state_dict(self, state_dict):
        """Restore configuration and position from :meth:`state_dict` output.

        Only scheduler attributes are restored; the optimizer's lr is not
        rewritten until the next :meth:`step` call.
        """
        self.base_lr = state_dict['base_lr']
        self.warmup_steps = state_dict['warmup_steps']
        self.total_steps = state_dict['total_steps']
        self.decay_mode = state_dict['decay_mode']
        # Older checkpoints did not store min_lr; keep the current value for those.
        self.min_lr = state_dict.get('min_lr', self.min_lr)
        self.current_step = state_dict['current_step']