Spaces:
Configuration error
Configuration error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| from collections import OrderedDict | |
| from detectron2.checkpoint import DetectionCheckpointer | |
| def _rename_HRNet_weights(weights): | |
| # We detect and rename HRNet weights for DensePose. 1956 and 1716 are values that are | |
| # common to all HRNet pretrained weights, and should be enough to accurately identify them | |
| if ( | |
| len(weights["model"].keys()) == 1956 | |
| and len([k for k in weights["model"].keys() if k.startswith("stage")]) == 1716 | |
| ): | |
| hrnet_weights = OrderedDict() | |
| for k in weights["model"].keys(): | |
| hrnet_weights["backbone.bottom_up." + str(k)] = weights["model"][k] | |
| return {"model": hrnet_weights} | |
| else: | |
| return weights | |
| class DensePoseCheckpointer(DetectionCheckpointer): | |
| """ | |
| Same as :class:`DetectionCheckpointer`, but is able to handle HRNet weights | |
| """ | |
| def __init__(self, model, save_dir="", *, save_to_disk=None, **checkpointables): | |
| super().__init__(model, save_dir, save_to_disk=save_to_disk, **checkpointables) | |
| def _load_file(self, filename: str) -> object: | |
| """ | |
| Adding hrnet support | |
| """ | |
| weights = super()._load_file(filename) | |
| return _rename_HRNet_weights(weights) | |