# Hugging Face Space demo, running on ZeroGPU ("Zero") hardware.
import numpy as np | |
import torch | |
import gradio as gr | |
from PIL import Image | |
from net.CIDNet import CIDNet | |
import torchvision.transforms as transforms | |
import torch.nn.functional as F | |
import os | |
import imquality.brisque as brisque | |
from loss.niqe_utils import * | |
import spaces | |
import huggingface_hub | |
# Single CIDNet instance shared by every request; process_image reloads its
# weights from the chosen model on each call. The `gated` and `gated2` flags of
# its `trans` submodule are both switched on (in that order) right after
# construction.
eval_net = CIDNet().cuda()
eval_net.trans.gated = eval_net.trans.gated2 = True
@spaces.GPU  # ZeroGPU: CUDA is only attached inside functions decorated like this
def process_image(input_img, score, model_path, gamma=1.0, alpha_s=1.0, alpha_i=1.0):
    """Enhance a low-light image with the shared CIDNet model.

    Args:
        input_img: PIL image from the Gradio input, or None if nothing was
            uploaded.
        score: "Yes" to also compute a NIQE score for the result; any other
            value (including None) skips scoring.
        model_path: weights identifier passed to ``eval_net.from_pretrained``.
        gamma: exponent applied to the input tensor (values in [0, 1] after
            ``ToTensor``) before inference.
        alpha_s: value written to ``eval_net.trans.alpha_s`` before inference.
        alpha_i: value written to ``eval_net.trans.alpha`` before inference.

    Returns:
        ``(image, niqe)``: the enhanced PIL image and its NIQE score (``0``
        when scoring is off). If no model or no image is given, returns a
        passthrough image plus a message string for the NIQE textbox.
    """
    if model_path is None:
        return input_img, "Please choose a model weights."
    if input_img is None:
        return None, "Please upload an image."

    with torch.no_grad():
        # NOTE(review): if CIDNet uses huggingface_hub.PyTorchModelHubMixin,
        # `from_pretrained` is a classmethod that returns a *new* model and
        # leaves `eval_net` untouched. Copy any returned weights into the
        # shared instance so the selected model actually takes effect.
        # If the method instead loads in place and returns None or self,
        # this is a no-op.
        loaded = eval_net.from_pretrained(model_path)
        if isinstance(loaded, torch.nn.Module) and loaded is not eval_net:
            eval_net.load_state_dict(loaded.state_dict())
        eval_net.eval()

        tensor = transforms.ToTensor()(input_img).unsqueeze(0)  # (1, C, h, w)
        h, w = tensor.shape[2], tensor.shape[3]

        # Pad height/width up to the next multiple of 8 (presumably required
        # by the network's down/up-sampling stages — confirm).
        factor = 8
        padh = (-h) % factor
        padw = (-w) % factor
        # Reflect padding requires pad < dimension; very small images would
        # raise, so fall back to edge replication for them.
        mode = "reflect" if padh < h and padw < w else "replicate"
        tensor = F.pad(tensor, (0, padw, 0, padh), mode)

        eval_net.trans.alpha_s = alpha_s
        eval_net.trans.alpha = alpha_i
        output = eval_net(tensor.cuda() ** gamma)
        # Clamp to the displayable range, move to CPU for PIL, and crop the
        # padding back off.
        output = torch.clamp(output, 0, 1).cpu()[:, :, :h, :w]

    enhanced_img = transforms.ToPILImage()(output.squeeze(0))
    if score == "Yes":
        return enhanced_img, calculate_niqe(np.array(enhanced_img))
    return enhanced_img, 0
# Local checkpoint folder; only referenced by the commented-out
# torch.load path inside process_image.
directory = "weights"

# Model identifiers shown as choices in the "Model Weights" radio; the selected
# one is passed to eval_net.from_pretrained by process_image.
pth_files = [
    "HVI-CIDNet-Generalization",
    "HVI-CIDNet-Sony-Total-Dark",
    "HVI-CIDNet-LOL-Blur",
    "HVI-CIDNet-SICE",
    "HVI-CIDNet-LOLv2-real-bestSSIM",
    "HVI-CIDNet-LOLv2-real-bestPSNR",
    "HVI-CIDNet-LOLv2-syn-wperc",
    "HVI-CIDNet-LOLv2-syn-woperc",
    "HVI-CIDNet-LOLv1-wperc",
    "HVI-CIDNet-LOLv1-woperc",
]
# Gradio UI wiring. The input components map positionally onto
# process_image(input_img, score, model_path, gamma, alpha_s, alpha_i), and the
# two outputs receive its (image, niqe) return tuple.
interface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(label="Low-light Image", type="pil"),
        gr.Radio(
            choices=["Yes", "No"],
            value="No",  # matches the documented default below
            label="Image Score",
            info="Calculate NIQE, default is \"No\".",
        ),
        gr.Radio(
            # Was `pth_files2`, an undefined name (NameError at import).
            choices=pth_files,
            label="Model Weights",
            info=(
                "Choose your model. The best models are "
                "\"HVI-CIDNet-SICE\" and \"HVI-CIDNet-Generalization\"."
            ),
        ),
        gr.Slider(0.1, 5, label="gamma curve", step=0.01, value=1.0,
                  info="Lower is lighter, best range is [0.5,2.5]."),
        gr.Slider(0, 2, label="Alpha-s", step=0.01, value=1.0,
                  info="Higher is more saturated."),
        gr.Slider(0.1, 2, label="Alpha-i", step=0.01, value=1.0,
                  info="Higher is lighter."),
    ],
    outputs=[
        gr.Image(label="Result", type="pil"),
        gr.Textbox(label="NIQE", info="Lower is better."),
    ],
    title="HVI-CIDNet (Low-Light Image Enhancement)",
    description="The demo of paper \"You Only Need One Color Space: An Efficient Network for Low-light Image Enhancement\"",
    allow_flagging="never",
)
interface.launch(share=True)