File size: 508 Bytes
f7009b3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 |
import numpy as np
import torch
import sys
import os
file = sys.argv[1]
model = torch.load(file, map_location="cpu")
if 'meta' in model.keys():
print("this file need not to convert.")
exit(0)
else: # this is a raw checkpoint
meta_file = os.path.join(os.path.dirname(__file__), "Segmentation/example.pth")
meta_data = torch.load(meta_file, map_location="cpu")['meta']
model = {'meta': meta_data, "state_dict": model}
torch.save(model, file)
print("converted to test-able file.")
|