# app.py — Gradio demo for HVI-CIDNet low-light image enhancement
# (Hugging Face Space, running on ZeroGPU hardware).
import numpy as np
import torch
import gradio as gr
from PIL import Image
from net.CIDNet import CIDNet
import torchvision.transforms as transforms
import torch.nn.functional as F
import safetensors.torch as sf
import imquality.brisque as brisque  # imported but not used below
from loss.niqe_utils import *
import spaces
from huggingface_hub import hf_hub_download
import json
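
# Minimal Hub loader: download the checkpoint's config.json and
# model.safetensors from the Hugging Face Hub and load the weights into an
# already-constructed model.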
def from_pretrained(cls, pretrained_model_name_or_path: str):
    # Despite the name, `cls` is a model instance, not a class: weights are loaded in place.
    model_id = str(pretrained_model_name_or_path)
    config_file = hf_hub_download(repo_id=model_id, filename="config.json", repo_type="model")
    config = None
    if config_file is not None:
        with open(config_file, "r", encoding="utf-8") as f:
            config = json.load(f)  # downloaded for completeness; not used below
    model_file = hf_hub_download(repo_id=model_id, filename="model.safetensors", repo_type="model")
    # instance = sf.load_model(cls, model_file, strict=False)
    state_dict = sf.load_file(model_file)
    cls.load_state_dict(state_dict, strict=False)
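
# Build the network once on the GPU at import time and enable both gated
# branches of its HVI transform.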
eval_net = CIDNet().cuda()
eval_net.trans.gated = True
eval_net.trans.gated2 = True
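
# On ZeroGPU Spaces, @spaces.GPU attaches a GPU only for the duration of the
# call (here capped at 120 seconds).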
@spaces.GPU(duration=120)
def process_image(input_img, score, model_path, gamma=1.0, alpha_s=1.0, alpha_i=1.0):
    if model_path is None:
        return input_img, "Please choose model weights."
    torch.set_grad_enabled(False)
    from_pretrained(eval_net, "Fediory/HVI-CIDNet-" + model_path)
    # eval_net.load_state_dict(torch.load(os.path.join(directory,model_path), map_location=lambda storage, loc: storage))
    eval_net.eval()
    pil2tensor = transforms.Compose([transforms.ToTensor()])
    inp = pil2tensor(input_img)
    # Pad height and width up to the next multiple of 8 so the network's
    # downsampling stages divide evenly; the padding is cropped off afterwards.
    factor = 8
    h, w = inp.shape[1], inp.shape[2]
    H, W = ((h + factor) // factor) * factor, ((w + factor) // factor) * factor
    padh = H - h if h % factor != 0 else 0
    padw = W - w if w % factor != 0 else 0
    inp = F.pad(inp.unsqueeze(0), (0, padw, 0, padh), 'reflect')
    with torch.no_grad():
        eval_net.trans.alpha_s = alpha_s
        eval_net.trans.alpha = alpha_i
        output = eval_net(inp.cuda() ** gamma)  # gamma curve is applied to the input beforehand
    output = torch.clamp(output, 0, 1)
    output = output[:, :, :h, :w]  # crop the reflect padding back off
    enhanced_img = transforms.ToPILImage()(output.squeeze(0))
    if score == 'Yes':
        im1 = np.array(enhanced_img)
        score_niqe = calculate_niqe(im1)
        return enhanced_img, score_niqe
    else:
        return enhanced_img, 0
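
# Checkpoint variants published on the Hub as "Fediory/HVI-CIDNet-<name>";
# `directory` is only referenced by the commented-out local-loading path above.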
directory = "weights"
pth_files = [
'Generalization',
'Sony-Total-Dark',
'LOL-Blur',
'SICE',
'LOLv2-real-bestSSIM',
'LOLv2-real-bestPSNR',
'LOLv2-syn-wperc',
'LOLv2-syn-woperc',
'LOLv1-wperc',
'LOLv1-woperc'
]
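
# Gradio UI: a low-light image in; the enhanced image and an optional NIQE
# score out.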
interface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(label="Low-light Image", type="pil"),
        gr.Radio(choices=['Yes', 'No'], label="Image Score", info="Calculate NIQE; default is \"No\"."),
        gr.Radio(choices=pth_files, label="Model Weights", info="Choose your model. The best models are \"SICE\" and \"Generalization\"."),
        gr.Slider(0.1, 5, label="gamma curve", step=0.01, value=1.0, info="Lower is lighter; the best range is [0.5, 2.5]."),
        gr.Slider(0, 2, label="Alpha-s", step=0.01, value=1.0, info="Higher is more saturated."),
        gr.Slider(0.1, 2, label="Alpha-i", step=0.01, value=1.0, info="Higher is lighter.")
    ],
    outputs=[
        gr.Image(label="Result", type="pil"),
        gr.Textbox(label="NIQE", info="Lower is better.")
    ],
    title="Light Amplification",
    description="Demo of the paper \"HVI: A New Color Space for Low-light Image Enhancement\"",
    allow_flagging="never"
)
interface.launch(share=False)
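
# A minimal sketch of calling the handler directly, outside Gradio — assumes a
# CUDA device and a hypothetical local file "low.png":
#
#   from PIL import Image
#   enhanced, niqe = process_image(Image.open("low.png"), score="Yes", model_path="SICE")
#   enhanced.save("enhanced.png")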