import numpy as np
import torch
import gradio as gr
from PIL import Image
from net.CIDNet import CIDNet
import torchvision.transforms as transforms
import torch.nn.functional as F
import safetensors.torch as sf
import imquality.brisque as brisque
from loss.niqe_utils import *
import spaces
from huggingface_hub import hf_hub_download
import json


def from_pretrained(cls, pretrained_model_name_or_path: str):
    """Load ``model.safetensors`` from a Hugging Face Hub repo into a model.

    Despite the parameter name, ``cls`` is a model *instance* (``eval_net``
    below). Its parameters are overwritten in place and nothing is returned.

    Args:
        cls: module whose ``load_state_dict`` receives the checkpoint.
        pretrained_model_name_or_path: Hub repo id, e.g.
            ``"Fediory/HVI-CIDNet-SICE"``.
    """
    model_id = str(pretrained_model_name_or_path)
    # hf_hub_download keeps a local cache, so repeated calls for the same
    # repo do not re-download the file.
    model_file = hf_hub_download(repo_id=model_id, filename="model.safetensors", repo_type="model")
    state_dict = sf.load_file(model_file)
    # strict=False: checkpoint keys the model lacks are ignored, and model
    # parameters missing from the checkpoint keep their current values.
    cls.load_state_dict(state_dict, strict=False)


# Single shared network; weights are swapped in per request by process_image.
eval_net = CIDNet().cuda()
eval_net.trans.gated = True
eval_net.trans.gated2 = True

# Hub repo id whose weights are currently loaded into eval_net (None = none yet).
_loaded_model_id = None

# Spatial dims are padded up to a multiple of this before inference.
_PAD_FACTOR = 8


def _pad_to_multiple(batch, factor):
    """Pad an (N, C, H, W) tensor on the bottom/right so H and W are multiples of ``factor``."""
    h, w = batch.shape[-2], batch.shape[-1]
    padh = -h % factor
    padw = -w % factor
    if padh == 0 and padw == 0:
        return batch
    # 'reflect' requires every pad amount to be smaller than the padded
    # dimension; very small images fall back to 'replicate' instead of raising.
    mode = 'reflect' if padh < h and padw < w else 'replicate'
    return F.pad(batch, (0, padw, 0, padh), mode)


@spaces.GPU(duration=120)
def process_image(input_img, score, model_path, gamma=1.0, alpha_s=1.0, alpha_i=1.0):
    """Enhance a low-light image with the selected CIDNet weights.

    Args:
        input_img: PIL image from the Gradio input (None if nothing uploaded).
        score: 'Yes' to also compute NIQE on the result; anything else skips it.
        model_path: suffix of the Hub repo id (an entry of ``pth_files``);
            the weights come from ``"Fediory/HVI-CIDNet-" + model_path``.
        gamma: exponent applied to the [0, 1] input tensor before inference.
        alpha_s: assigned to ``eval_net.trans.alpha_s``.
        alpha_i: assigned to ``eval_net.trans.alpha``.

    Returns:
        ``(image, text)``: the enhanced PIL image and the NIQE score (or 0 when
        scoring is off). If no model or no image was given, returns a
        placeholder image and a message instead.
    """
    global _loaded_model_id

    if model_path is None:
        return input_img, "Please choose a model weights."
    if input_img is None:
        return None, "Please upload a low-light image."

    # Only hit the Hub / re-read the checkpoint when the selection changes.
    model_id = "Fediory/HVI-CIDNet-" + model_path
    if model_id != _loaded_model_id:
        from_pretrained(eval_net, model_id)
        _loaded_model_id = model_id
    eval_net.eval()

    # NOTE(review): assumes CIDNet takes 3-channel RGB input; RGBA or
    # greyscale images would otherwise yield 4 or 1 channels.
    if input_img.mode != "RGB":
        input_img = input_img.convert("RGB")

    image_tensor = transforms.ToTensor()(input_img)
    h, w = image_tensor.shape[1], image_tensor.shape[2]
    batch = _pad_to_multiple(image_tensor.unsqueeze(0), _PAD_FACTOR)

    with torch.no_grad():
        eval_net.trans.alpha_s = alpha_s
        eval_net.trans.alpha = alpha_i
        output = eval_net(batch.cuda() ** gamma)
        output = torch.clamp(output, 0, 1)

    # Crop the padding away, then bring the result to the CPU for PIL.
    output = output[:, :, :h, :w].cpu()
    enhanced_img = transforms.ToPILImage()(output.squeeze(0))

    if score == 'Yes':
        score_niqe = calculate_niqe(np.array(enhanced_img))
        return enhanced_img, score_niqe
    return enhanced_img, 0


# Currently unused; kept as the location for local weight files.
directory = "weights"

# Suffixes of the "Fediory/HVI-CIDNet-<name>" Hub repos offered in the UI.
pth_files = [
    'Generalization',
    'Sony-Total-Dark',
    'LOL-Blur',
    'SICE',
    'LOLv2-real-bestSSIM',
    'LOLv2-real-bestPSNR',
    'LOLv2-syn-wperc',
    'LOLv2-syn-woperc',
    'LOLv1-wperc',
    'LOLv1-woperc'
]

interface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(label="Low-light Image", type="pil"),
        gr.Radio(choices=['Yes','No'],label="Image Score",info="Calculate NIQE, default is \"No\"."),
        gr.Radio(choices=pth_files,label="Model Weights",info="Choose your model. The best models are \"SICE\" and \"Generalization\"."),
        gr.Slider(0.1,5,label="gamma curve",step=0.01,value=1.0, info="Lower is lighter, best range is [0.5,2.5]."),
        gr.Slider(0,2,label="Alpha-s",step=0.01,value=1.0, info="Higher is more saturated."),
        gr.Slider(0.1,2,label="Alpha-i",step=0.01,value=1.0, info="Higher is lighter.")
    ],
    outputs=[
        gr.Image(label="Result", type="pil"),
        gr.Textbox(label="NIQE",info="Lower is better.")
    ],
    title="Light Amplification",
    description="The demo of paper \"HVI: A New Color Space for Low-light Image Enhancement\"",
    allow_flagging="never"
)

interface.launch(share=False)