# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2020 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: Vassilis Choutas, [email protected]
import sys
from typing import NewType, List, Dict
import torch
import torch.optim as optim
from loguru import logger
from torchtrustncg import TrustRegion
Tensor = NewType('Tensor', torch.Tensor)
def build_optimizer(parameters: List[Tensor],
                    optim_cfg: Dict
                    ) -> Dict:
    ''' Creates the optimizer

        Parameters
        ----------
        parameters: List[Tensor]
            The parameters to optimize. Entries with
            ``requires_grad == False`` are filtered out before the
            optimizer is built.
        optim_cfg: Dict
            Configuration dictionary. The ``type`` key selects the
            optimizer ('adam', 'lbfgs'/'lbfgsls', 'trust_ncg'/'trust-ncg',
            'rmsprop' or 'sgd'; default 'sgd'); an optional sub-dictionary
            under the same name holds keyword arguments forwarded to the
            optimizer constructor.

        Returns
        -------
        Dict with two keys:
            ``optimizer``: the constructed optimizer object.
            ``create_graph``: bool, True only for the trust-region
            optimizer, which needs higher-order gradients.

        Raises
        ------
        ValueError if the requested optimizer type is not supported.
    '''
    optim_type = optim_cfg.get('type', 'sgd')
    logger.info(f'Building: {optim_type.title()}')

    # Keep only parameters that actually require gradients.
    num_params = len(parameters)
    parameters = [p for p in parameters if p.requires_grad]
    if num_params != len(parameters):
        # NOTE: was an f-string with no placeholders (lint F541); message
        # text kept identical, the redundant prefix removed.
        logger.info('Some parameters have requires_grad off')

    if optim_type == 'adam':
        optimizer = optim.Adam(parameters, **optim_cfg.get('adam', {}))
        create_graph = False
    elif optim_type in ('lbfgs', 'lbfgsls'):
        optimizer = optim.LBFGS(parameters, **optim_cfg.get('lbfgs', {}))
        create_graph = False
    elif optim_type in ('trust_ncg', 'trust-ncg'):
        # Second-order method: the loss backward pass must build the graph
        # so Hessian-vector products can be computed.
        optimizer = TrustRegion(
            parameters, **optim_cfg.get('trust_ncg', {}))
        create_graph = True
    elif optim_type == 'rmsprop':
        optimizer = optim.RMSprop(parameters, **optim_cfg.get('rmsprop', {}))
        create_graph = False
    elif optim_type == 'sgd':
        optimizer = optim.SGD(parameters, **optim_cfg.get('sgd', {}))
        create_graph = False
    else:
        raise ValueError(f'Optimizer {optim_type} not supported!')
    return {'optimizer': optimizer, 'create_graph': create_graph}
def build_scheduler(optimizer, sched_type='exp',
                    lr_lambda=0.1, **kwargs):
    ''' Creates a learning-rate scheduler for the given optimizer

        Parameters
        ----------
        optimizer
            The optimizer whose learning rate should be scheduled.
        sched_type: str
            Scheduler family; only 'exp' (ExponentialLR) is supported.
        lr_lambda: float
            Multiplicative decay factor (gamma). A non-positive value
            disables scheduling and returns None.

        Returns
        -------
        An ExponentialLR instance, or None when scheduling is disabled.

        Raises
        ------
        ValueError if sched_type is not recognized.
    '''
    # A non-positive decay factor means "no scheduling".
    if lr_lambda <= 0.0:
        return None
    if sched_type == 'exp':
        return optim.lr_scheduler.ExponentialLR(optimizer, lr_lambda)
    # BUG FIX: the original built the message with
    # ' scheduler: '.format(sched_type) — a literal with no '{}'
    # placeholder — so the offending name was silently dropped.
    raise ValueError(f'Unknown learning rate scheduler: {sched_type}')