"""Re-key a Swin Transformer checkpoint for use as a Detectron2 backbone.

The script loads the ``"model"`` entry of a ``.pth`` checkpoint, prepends
``backbone.bottom_up.`` to every key, and saves the resulting flat mapping.
"""

import argparse
import os

import torch

# Prefix prepended to every source key in the converted checkpoint.
BACKBONE_PREFIX = "backbone.bottom_up."


def parse_args(argv=None):
    """Parse the command-line arguments.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            argparse reads ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``source_model`` and ``output_model``.
    """
    # The text is passed as ``description``: the first positional parameter of
    # ArgumentParser is ``prog``, which would put it in the usage line instead.
    parser = argparse.ArgumentParser(
        description="Convert Swin Transformer to Detectron2"
    )
    # Both arguments are required positionals, so no ``default`` is given
    # (argparse would never use one here).
    parser.add_argument("source_model", type=str, help="Source model")
    parser.add_argument("output_model", type=str, help="Output model")
    return parser.parse_args(argv)


def _add_prefix(weights, prefix):
    """Return a new dict with ``prefix`` prepended to every key of ``weights``."""
    return {prefix + name: value for name, value in weights.items()}


def main(argv=None):
    """Convert the source checkpoint and write the re-keyed weights.

    Args:
        argv: Optional list of argument strings forwarded to ``parse_args``.
            When ``None`` (the default), ``sys.argv[1:]`` is used.

    Raises:
        ValueError: If the source path does not have a ``.pth`` extension.
        KeyError: If the loaded checkpoint has no ``"model"`` entry.
    """
    args = parse_args(argv)

    # Validate the extension before touching the file system.
    if os.path.splitext(args.source_model)[-1] != ".pth":
        raise ValueError("You should save weights as pth file")

    # NOTE: torch.load relies on pickle (how restrictive it is depends on the
    # installed torch version's ``weights_only`` default); only convert
    # checkpoints that come from a trusted source.
    checkpoint = torch.load(
        args.source_model, map_location=torch.device("cpu")
    )
    source_weights = checkpoint["model"]

    converted_weights = _add_prefix(source_weights, BACKBONE_PREFIX)
    torch.save(converted_weights, args.output_model)


if __name__ == "__main__":
    main()