Fediory commited on
Commit
5b15317
·
1 Parent(s): c97d56f
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -9,7 +9,7 @@ import os
9
  import imquality.brisque as brisque
10
  from loss.niqe_utils import *
11
 
12
- eval_net = CIDNet()
13
  eval_net.trans.gated = True
14
  eval_net.trans.gated2 = True
15
 
@@ -31,8 +31,8 @@ def process_image(input_img,score,model_path,gamma=1.0,alpha_s=1.0,alpha_i=1.0):
31
  with torch.no_grad():
32
  eval_net.trans.alpha_s = alpha_s
33
  eval_net.trans.alpha = alpha_i
34
- output = eval_net(input**gamma)
35
- output = torch.clamp(output,0,1)
36
  output = output[:, :, :h, :w]
37
  enhanced_img = transforms.ToPILImage()(output.squeeze(0))
38
  if score == 'Yes':
 
9
  import imquality.brisque as brisque
10
  from loss.niqe_utils import *
11
 
12
+ eval_net = CIDNet().cuda()
13
  eval_net.trans.gated = True
14
  eval_net.trans.gated2 = True
15
 
 
31
  with torch.no_grad():
32
  eval_net.trans.alpha_s = alpha_s
33
  eval_net.trans.alpha = alpha_i
34
+ output = eval_net(input.cuda()**gamma)
35
+ output = torch.clamp(output,0,1).cuda()
36
  output = output[:, :, :h, :w]
37
  enhanced_img = transforms.ToPILImage()(output.squeeze(0))
38
  if score == 'Yes':