Fediory commited on
Commit
6319557
·
1 Parent(s): da0c19f

fix: weights

Browse files
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -13,7 +13,9 @@ eval_net = CIDNet()
13
  eval_net.trans.gated = True
14
  eval_net.trans.gated2 = True
15
 
16
- def process_image(input_img,score,model_path,gamma,alpha_s=1.0,alpha_i=1.0):
 
 
17
  torch.set_grad_enabled(False)
18
  eval_net.load_state_dict(torch.load(os.path.join(directory,model_path), map_location=lambda storage, loc: storage))
19
  eval_net.eval()
@@ -65,7 +67,7 @@ interface = gr.Interface(
65
  inputs=[
66
  gr.Image(label="Low-light Image", type="pil"),
67
  gr.Radio(choices=['Yes','No'],label="Image Score"),
68
- gr.Radio(choices=pth_files2,label="Model Path",info="Choose your model. The best model is \"generalization.pth\"."),
69
  gr.Slider(0.1,10,label="gamma curve",step=0.01,value=1.0, info="Best range is [0.5,2.5]."),
70
  gr.Slider(0,2,label="Alpha-s",step=0.01,value=1.0, info="Higher is more saturated."),
71
  gr.Slider(0.1,2,label="Alpha-i",step=0.01,value=1.0, info="Higher is more lighted.")
 
13
  eval_net.trans.gated = True
14
  eval_net.trans.gated2 = True
15
 
16
+ def process_image(input_img,score,model_path,gamma=1.0,alpha_s=1.0,alpha_i=1.0):
17
+ if len(model_path) == 0:
18
+ return input_img,"Please choose a model weights.","Please choose a model weights."
19
  torch.set_grad_enabled(False)
20
  eval_net.load_state_dict(torch.load(os.path.join(directory,model_path), map_location=lambda storage, loc: storage))
21
  eval_net.eval()
 
67
  inputs=[
68
  gr.Image(label="Low-light Image", type="pil"),
69
  gr.Radio(choices=['Yes','No'],label="Image Score"),
70
+ gr.Radio(choices=pth_files2,label="Model Weights",info="Choose your model. The best models are \"SICE.pth\" and \"generalization.pth\"."),
71
  gr.Slider(0.1,10,label="gamma curve",step=0.01,value=1.0, info="Best range is [0.5,2.5]."),
72
  gr.Slider(0,2,label="Alpha-s",step=0.01,value=1.0, info="Higher is more saturated."),
73
  gr.Slider(0.1,2,label="Alpha-i",step=0.01,value=1.0, info="Higher is more lighted.")