File size: 3,566 Bytes
3e648fb
 
 
 
 
 
 
d86981d
3e648fb
 
729b4cc
d86981d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e648fb
5b15317
3e648fb
 
 
729b4cc
6319557
f625260
8ced22f
3e648fb
d86981d
2d65752
3e648fb
 
 
 
 
 
 
 
 
 
 
 
 
5b15317
 
3e648fb
 
 
8ced22f
3e648fb
8ced22f
3e648fb
8ced22f
3e648fb
 
be1e53a
2d65752
a7404ed
 
 
 
 
 
 
 
 
 
2d65752
 
3e648fb
 
 
 
 
c9dce21
fb17c25
c9dce21
da0c19f
c9dce21
3e648fb
 
 
c9dce21
3e648fb
c139b89
fb17c25
3e648fb
 
 
d86981d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import numpy as np
import torch
import gradio as gr
from PIL import Image
from net.CIDNet import CIDNet
import torchvision.transforms as transforms
import torch.nn.functional as F
import safetensors.torch as sf
import imquality.brisque as brisque
from loss.niqe_utils import *
import spaces
from huggingface_hub import hf_hub_download
import json

def from_pretrained(cls, pretrained_model_name_or_path: str):
    """Load ``model.safetensors`` from a Hugging Face model repo into a module.

    Despite its name, ``cls`` is an already-constructed module instance (the
    caller passes ``eval_net``); its parameters are overwritten in place via
    ``load_state_dict``.

    Args:
        cls: Module instance whose state dict is updated.
        pretrained_model_name_or_path: Hugging Face repo id, e.g.
            ``"Fediory/HVI-CIDNet-SICE"``.

    Returns:
        The same module instance, so the call can be chained.
    """
    model_id = str(pretrained_model_name_or_path)

    # Only the weights are needed. The repo's config.json used to be
    # downloaded and parsed here too, but the result was never used.
    model_file = hf_hub_download(repo_id=model_id, filename="model.safetensors", repo_type="model")
    state_dict = sf.load_file(model_file)
    # strict=False: keys missing from the checkpoint keep their current values
    # and unexpected checkpoint keys are ignored instead of raising.
    cls.load_state_dict(state_dict, strict=False)
    return cls

# Shared CIDNet instance, built once at import time and moved to the GPU.
# Both gating switches on its `trans` submodule are turned on.
eval_net = CIDNet()
eval_net = eval_net.cuda()
eval_net.trans.gated = eval_net.trans.gated2 = True

@spaces.GPU(duration=120)
def process_image(input_img, score, model_path, gamma=1.0, alpha_s=1.0, alpha_i=1.0):
    """Enhance a low-light image with the selected CIDNet weights.

    Args:
        input_img: PIL image from the Gradio input; ``None`` when no image
            was provided.
        score: ``'Yes'`` to also compute the NIQE score of the result; any
            other value skips scoring.
        model_path: Suffix of the ``Fediory/HVI-CIDNet-<model_path>`` repo
            whose weights are loaded into ``eval_net``.
        gamma: Exponent applied elementwise to the [0, 1] input tensor
            before inference.
        alpha_s: Value assigned to ``eval_net.trans.alpha_s``.
        alpha_i: Value assigned to ``eval_net.trans.alpha``.

    Returns:
        ``(image, score_or_message)``: the enhanced PIL image and either the
        NIQE score, ``0`` when scoring is off, or a user-facing message when
        an input is missing.
    """
    if model_path is None:
        return input_img, "Please choose a model weights."
    if input_img is None:
        return None, "Please upload an image."

    from_pretrained(eval_net, "Fediory/HVI-CIDNet-" + model_path)
    eval_net.eval()

    # The network expects 3 channels; grayscale/RGBA uploads would otherwise
    # yield 1- or 4-channel tensors.
    img = transforms.ToTensor()(input_img.convert("RGB"))
    factor = 8
    h, w = img.shape[1], img.shape[2]
    # Pad bottom/right up to the next multiple of `factor` (0 if already aligned).
    padh = (-h) % factor
    padw = (-w) % factor
    # 'reflect' requires each pad to be smaller than that dimension; only
    # images smaller than `factor` pixels can violate this.
    pad_mode = "reflect" if padh < h and padw < w else "replicate"
    img = F.pad(img.unsqueeze(0), (0, padw, 0, padh), pad_mode)

    with torch.no_grad():
        eval_net.trans.alpha_s = alpha_s
        eval_net.trans.alpha = alpha_i
        output = eval_net(img.cuda() ** gamma)

    # Clamp to the valid range, crop the padding away and bring the result
    # back to host memory for PIL conversion.
    output = torch.clamp(output, 0, 1)[:, :, :h, :w].cpu()
    enhanced_img = transforms.ToPILImage()(output.squeeze(0))

    if score == 'Yes':
        score_niqe = calculate_niqe(np.array(enhanced_img))
        return enhanced_img, score_niqe
    return enhanced_img, 0


# Local weights folder name; it appears unused by live code in this file and
# is kept for compatibility.
directory = "weights"

# Choices for the "Model Weights" radio. Each entry is the suffix of the
# Hugging Face repo "Fediory/HVI-CIDNet-<entry>" that gets loaded.
pth_files = [
    # Generalization / other benchmark checkpoints
    "Generalization",
    "Sony-Total-Dark",
    "LOL-Blur",
    "SICE",
    # LOLv2 checkpoints (real and synthetic splits)
    "LOLv2-real-bestSSIM",
    "LOLv2-real-bestPSNR",
    "LOLv2-syn-wperc",
    "LOLv2-syn-woperc",
    # LOLv1 checkpoints
    "LOLv1-wperc",
    "LOLv1-woperc",
]


# Gradio UI. The inputs map positionally onto process_image's parameters
# (input_img, score, model_path, gamma, alpha_s, alpha_i). The outputs map
# onto its (image, NIQE-or-message) return tuple.
interface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(label="Low-light Image", type="pil"),
        # value='No' pre-selects the option the info text calls the default.
        # process_image treats anything other than 'Yes' as "no score".
        gr.Radio(choices=['Yes','No'],value='No',label="Image Score",info="Calculate NIQE, default is \"No\"."),
        gr.Radio(choices=pth_files,label="Model Weights",info="Choose your model. The best models are \"SICE\" and \"Generalization\"."),
        gr.Slider(0.1,5,label="gamma curve",step=0.01,value=1.0, info="Lower is lighter, best range is [0.5,2.5]."),
        gr.Slider(0,2,label="Alpha-s",step=0.01,value=1.0, info="Higher is more saturated."),
        gr.Slider(0.1,2,label="Alpha-i",step=0.01,value=1.0, info="Higher is lighter.")
    ],
    outputs=[
        gr.Image(label="Result", type="pil"),
        gr.Textbox(label="NIQE",info="Lower is better.")
    ],
    title="Light Amplification",
    description="The demo of paper \"HVI: A New Color Space for Low-light Image Enhancement\"",
    allow_flagging="never"
)

# Starts the web server as soon as the module runs (no public share link).
interface.launch(share=False)