52Hz commited on
Commit
12908a1
·
1 Parent(s): 038056b

Update main_test_SRMNet.py

Browse files
Files changed (1) hide show
  1. main_test_SRMNet.py +16 -2
main_test_SRMNet.py CHANGED
@@ -7,8 +7,6 @@ from skimage import img_as_ubyte
7
  from tqdm import tqdm
8
  from natsort import natsorted
9
  from glob import glob
10
- from utils.image_utils import save_img
11
- from utils.model_utils import load_checkpoint
12
  import argparse
13
  from model_arch.SRMNet_SWFF import SRMNet_SWFF
14
  from model_arch.SRMNet import SRMNet
@@ -89,6 +87,22 @@ def define_model(args):
89
 
90
  return model
91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
  if __name__ == '__main__':
94
  main()
 
7
  from tqdm import tqdm
8
  from natsort import natsorted
9
  from glob import glob
 
 
10
  import argparse
11
  from model_arch.SRMNet_SWFF import SRMNet_SWFF
12
  from model_arch.SRMNet import SRMNet
 
87
 
88
  return model
89
 
90
+ def save_img(filepath, img):
91
+ cv2.imwrite(filepath, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
92
+
93
+
94
+ def load_checkpoint(model, weights):
95
+ checkpoint = torch.load(weights, map_location=torch.device('cpu'))
96
+ try:
97
+ model.load_state_dict(checkpoint["state_dict"])
98
+ except:
99
+ state_dict = checkpoint["state_dict"]
100
+ new_state_dict = OrderedDict()
101
+ for k, v in state_dict.items():
102
+ name = k[7:] # remove `module.`
103
+ new_state_dict[name] = v
104
+ model.load_state_dict(new_state_dict)
105
+
106
 
107
  if __name__ == '__main__':
108
  main()