Spaces:
Runtime error
Runtime error
| """ | |
| @Date: 2021/07/18 | |
| @description: | |
| """ | |
| from torch import optim as optim | |
def build_optimizer(config, model, logger):
    """Construct a ``torch.optim`` optimizer for ``model`` from ``config``.

    ``config.TRAIN.OPTIMIZER.NAME`` is matched case-insensitively against:

    * ``'sgd'``   -> ``optim.SGD`` with Nesterov momentum
      (``config.TRAIN.OPTIMIZER.MOMENTUM``);
    * ``'adamw'`` -> ``optim.AdamW`` with ``config.TRAIN.OPTIMIZER.EPS`` and
      ``config.TRAIN.OPTIMIZER.BETAS``;
    * ``'adam'``  -> ``optim.Adam`` with the same ``EPS`` / ``BETAS``.

    Every variant uses ``config.TRAIN.BASE_LR`` as the learning rate and
    ``config.TRAIN.WEIGHT_DECAY`` as the weight decay.

    Args:
        config: Configuration object exposing the ``TRAIN`` attributes above.
        model: Object whose ``parameters()`` are handed to the optimizer.
        logger: Logger used to report the selected optimizer and learning rate.

    Returns:
        The constructed optimizer instance.

    Raises:
        ValueError: If the optimizer name is not ``sgd``, ``adamw`` or ``adam``.
    """
    train_cfg = config.TRAIN
    opt_cfg = train_cfg.OPTIMIZER
    name = opt_cfg.NAME.lower()
    # Hyper-parameters shared by every supported optimizer.
    common_kwargs = dict(lr=train_cfg.BASE_LR, weight_decay=train_cfg.WEIGHT_DECAY)

    if name == 'sgd':
        # NOTE: torch's SGD rejects nesterov=True unless momentum > 0.
        optimizer = optim.SGD(model.parameters(), momentum=opt_cfg.MOMENTUM,
                              nesterov=True, **common_kwargs)
    elif name in ('adamw', 'adam'):
        # Adam and AdamW are built with identical keyword arguments here.
        optimizer_cls = optim.AdamW if name == 'adamw' else optim.Adam
        optimizer = optimizer_cls(model.parameters(), eps=opt_cfg.EPS,
                                  betas=opt_cfg.BETAS, **common_kwargs)
    else:
        # Fail fast instead of returning None, which would only surface later
        # as an AttributeError at the first optimizer.step().
        raise ValueError(
            f"Unsupported optimizer: {opt_cfg.NAME!r} "
            "(expected one of: 'sgd', 'adamw', 'adam')"
        )

    logger.info(f"Build optimizer: {name}, lr:{config.TRAIN.BASE_LR}")
    return optimizer