File size: 262 Bytes
3f8f152
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
import torch

# Registry of supported optimizers: maps the name used as the config key to
# the torch.optim class that get_optimizer() instantiates for it.
OPTIMIZERS_POOL = {'sgd': torch.optim.SGD}

def get_optimizer(model_params, optimizer_config):
    """Build the optimizer described by a ``{name: params}`` config mapping.

    Args:
        model_params: Passed through unchanged as the first positional
            argument of the selected optimizer class.
        optimizer_config: Mapping whose first entry is ``name -> params``.
            ``name`` is looked up in ``OPTIMIZERS_POOL`` and ``params`` is
            expanded into keyword arguments for the optimizer class.
            ``params`` may be ``None`` (e.g. an empty config section), which
            is treated as "no keyword arguments". Entries after the first
            are ignored.

    Returns:
        The object created by calling the registered optimizer class.

    Raises:
        IndexError: If ``optimizer_config`` has no entries.
        KeyError: If ``name`` is not registered in ``OPTIMIZERS_POOL``.
    """
    try:
        # Only the first entry is needed, so avoid copying all items.
        name, params = next(iter(optimizer_config.items()))
    except StopIteration:
        # Same exception type the list-indexing version raised, so existing
        # handlers keep working, but with a message that names the problem.
        raise IndexError(
            "optimizer_config is empty; expected a single {name: params} entry"
        ) from None

    try:
        optimizer_cls = OPTIMIZERS_POOL[name]
    except KeyError:
        raise KeyError(
            f"unknown optimizer {name!r}; available: {list(OPTIMIZERS_POOL)}"
        ) from None

    # A config section with no value would otherwise crash in ``**params``.
    if params is None:
        params = {}

    return optimizer_cls(model_params, **params)