"""
Optimizer

Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
Please cite our work if the code is helpful to you.
"""
import torch

from pointcept.utils.logger import get_root_logger
from pointcept.utils.registry import Registry
# Registry of optimizer constructors; configs select one by its "type" name.
OPTIMIZERS = Registry("optimizers")

# Expose the common torch optimizers under their conventional names.
for _optim_name in ("SGD", "Adam", "AdamW"):
    OPTIMIZERS.register_module(module=getattr(torch.optim, _optim_name), name=_optim_name)
def build_optimizer(cfg, model, param_dicts=None):
    """Build an optimizer from the OPTIMIZERS registry.

    Args:
        cfg: optimizer config (dict-like with attribute access); must carry the
            registry ``type`` and base optimizer kwargs such as ``lr``.
            NOTE: mutated in place — ``cfg.params`` is filled with the
            parameter groups before building.
        model: model whose parameters will be optimized.
        param_dicts: optional list of dict-like overrides. Each entry must
            provide ``keyword`` (substring-matched against parameter names)
            and may override ``lr``, ``momentum`` and/or ``weight_decay``
            for the matched parameters.

    Returns:
        The optimizer instance built by ``OPTIMIZERS.build``.
    """
    if param_dicts is None:
        # No per-group overrides: hand the raw parameter iterator to torch.
        cfg.params = model.parameters()
    else:
        # Group 0 is the default group (base lr); group i + 1 mirrors
        # param_dicts[i]. "names" is kept alongside "params" purely for logging.
        cfg.params = [dict(names=[], params=[], lr=cfg.lr)]
        for param_dict in param_dicts:
            param_group = dict(names=[], params=[])
            # Copy only the hyper-parameter overrides actually specified.
            for key in ("lr", "momentum", "weight_decay"):
                if key in param_dict:
                    param_group[key] = param_dict[key]
            cfg.params.append(param_group)

        for name, param in model.named_parameters():
            # First matching keyword wins; unmatched params fall into group 0.
            for i, param_dict in enumerate(param_dicts):
                if param_dict.keyword in name:
                    cfg.params[i + 1]["names"].append(name)
                    cfg.params[i + 1]["params"].append(param)
                    break
            else:
                cfg.params[0]["names"].append(name)
                cfg.params[0]["params"].append(param)

        # Log each group's settings, then strip the logging-only "names" key
        # so torch.optim receives clean parameter-group dicts.
        logger = get_root_logger()
        for i, param_group in enumerate(cfg.params):
            param_names = param_group.pop("names")
            message = "".join(
                f" {key}: {value};"
                for key, value in param_group.items()
                if key != "params"
            )
            logger.info(f"Params Group {i+1} -{message} Params: {param_names}.")
    return OPTIMIZERS.build(cfg=cfg)