Spaces:
Runtime error
Runtime error
import torch | |
# Registry of supported optimizers: maps the name used as the key of an
# optimizer config to the torch.optim class that get_optimizer instantiates.
OPTIMIZERS_POOL = {
    'sgd': torch.optim.SGD,
}
def get_optimizer(model_params, optimizer_config):
    """Build an optimizer from a ``{name: kwargs}`` config mapping.

    Args:
        model_params: Passed unchanged as the first positional argument of
            the optimizer constructor.
        optimizer_config: Mapping whose first item is ``(name, kwargs)``.
            ``name`` is looked up in ``OPTIMIZERS_POOL`` and ``kwargs`` is
            expanded as keyword arguments to the constructor. Items after
            the first are ignored. A ``kwargs`` value of ``None`` is treated
            as "no keyword arguments".

    Returns:
        The optimizer instance created by the class registered under
        ``name``.

    Raises:
        IndexError: If ``optimizer_config`` is empty or ``None``.
        KeyError: If ``name`` is not a key of ``OPTIMIZERS_POOL``.
    """
    if not optimizer_config:
        # Same exception type the original `[0]` lookup produced on an empty
        # mapping, but with a message that says what was expected.
        raise IndexError(
            "optimizer_config is empty; expected a mapping {name: kwargs}"
        )

    # Only the first (name, kwargs) entry is used.
    name, params = next(iter(optimizer_config.items()))

    try:
        optimizer_cls = OPTIMIZERS_POOL[name]
    except KeyError:
        raise KeyError(
            f"Unknown optimizer {name!r}; available: {sorted(OPTIMIZERS_POOL)}"
        ) from None

    # A config entry with no values (e.g. a bare `sgd:` key) arrives as None;
    # expanding None with ** would fail, so fall back to no keyword arguments.
    kwargs = {} if params is None else params
    return optimizer_cls(model_params, **kwargs)