Prajwal-r-k commited on
Commit
e3eb054
·
verified ·
1 Parent(s): 75604f3

Upload demo.py

Browse files
Files changed (1) hide show
  1. NAFNet/demo.py +62 -0
NAFNet/demo.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Copyright (c) 2022 megvii-model. All Rights Reserved.
3
+ # ------------------------------------------------------------------------
4
+ # Modified from BasicSR (https://github.com/xinntao/BasicSR)
5
+ # Copyright 2018-2020 BasicSR Authors
6
+ # ------------------------------------------------------------------------
7
+ import torch
8
+
9
+ # from basicsr.data import create_dataloader, create_dataset
10
+ from basicsr.models import create_model
11
+ from basicsr.train import parse_options
12
+ from basicsr.utils import FileClient, imfrombytes, img2tensor, padding, tensor2img, imwrite
13
+
14
+ # from basicsr.utils import (get_env_info, get_root_logger, get_time_str,
15
+ # make_exp_dirs)
16
+ # from basicsr.utils.options import dict2str
17
+
18
+ def main():
19
+ # parse options, set distributed setting, set ramdom seed
20
+ opt = parse_options(is_train=False)
21
+ opt['num_gpu'] = torch.cuda.device_count()
22
+
23
+ img_path = opt['img_path'].get('input_img')
24
+ output_path = opt['img_path'].get('output_img')
25
+
26
+
27
+ ## 1. read image
28
+ file_client = FileClient('disk')
29
+
30
+ img_bytes = file_client.get(img_path, None)
31
+ try:
32
+ img = imfrombytes(img_bytes, float32=True)
33
+ except:
34
+ raise Exception("path {} not working".format(img_path))
35
+
36
+ img = img2tensor(img, bgr2rgb=True, float32=True)
37
+
38
+
39
+
40
+ ## 2. run inference
41
+ opt['dist'] = False
42
+ model = create_model(opt)
43
+
44
+ model.feed_data(data={'lq': img.unsqueeze(dim=0)})
45
+
46
+ if model.opt['val'].get('grids', False):
47
+ model.grids()
48
+
49
+ model.test()
50
+
51
+ if model.opt['val'].get('grids', False):
52
+ model.grids_inverse()
53
+
54
+ visuals = model.get_current_visuals()
55
+ sr_img = tensor2img([visuals['result']])
56
+ imwrite(sr_img, output_path)
57
+
58
+ print(f'inference {img_path} .. finished. saved to {output_path}')
59
+
60
+ if __name__ == '__main__':
61
+ main()
62
+