import functools

import gradio as gr
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image

from model.flol import create_model

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Auxiliary transform: PIL image -> float tensor of shape (C, H, W) scaled to [0, 1].
pil_to_tensor = transforms.ToTensor()

# Checkpoint paths selected by the "UHD-LL model" checkbox (a list, not a dict):
# index 0 -> UHD-LL weights, index 1 -> LOLv2-Real weights (per the app description).
image_to_weights = ['./weights/flolv2_UHDLL.pt', './weights/flolv2_all_111439.pt']

# Build the network once, move it to the target device a single time, and switch it
# to inference mode (disables training-only behaviour such as dropout, if present).
# Weights are swapped in per request by process_img.
model = create_model()
model.to(device)
model.eval()


@functools.lru_cache(maxsize=None)
def _load_checkpoint(model_path):
    """Load and memoise the state dict stored under ``'params'`` in *model_path*.

    Avoids re-reading the checkpoint from disk on every request. The cache key
    space is bounded by the entries of ``image_to_weights``.
    """
    checkpoint = torch.load(model_path, map_location=device)
    return checkpoint['params']


def load_img(filename):
    """Open *filename* as RGB and return it as a (C, H, W) float tensor in [0, 1]."""
    img = Image.open(filename).convert("RGB")
    img_tensor = pil_to_tensor(img)
    return img_tensor


def process_img(image, UHD_LL_model):
    """Enhance a low-light image with the FLOL model.

    Args:
        image: input ``PIL.Image`` from the Gradio image component (``None`` when
            the user submits without uploading anything).
        UHD_LL_model: if true, use the UHD-LL weights; otherwise the LOLv2-Real
            weights (``image_to_weights[1]``).

    Returns:
        The restored image as an 8-bit RGB ``PIL.Image``.

    Raises:
        gr.Error: if no image was provided, or if inference ran out of memory.
    """
    if image is None:
        raise gr.Error('Please upload an input image.')

    model_path = image_to_weights[0] if UHD_LL_model else image_to_weights[1]
    model.load_state_dict(_load_checkpoint(model_path))

    # Force 3 channels: RGBA / grayscale uploads would otherwise break permute(2, 0, 1).
    img = np.asarray(image.convert('RGB')) / 255.  # Normalize to [0, 1]
    img = img.astype(np.float32)
    y = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).to(device)  # (1, 3, H, W)

    try:
        with torch.no_grad():
            x_hat = model(y)
    except RuntimeError as err:
        # High-resolution inputs can exhaust GPU memory; report it instead of crashing.
        if 'out of memory' not in str(err):
            raise
        torch.cuda.empty_cache()
        raise gr.Error('Ran out of memory while processing; '
                       'try a lower-resolution image.') from err

    # Drop only the batch axis (a bare squeeze() would also drop a size-1 H or W axis).
    restored_img = x_hat[0].clamp(0, 1).permute(1, 2, 0).cpu().numpy()
    restored_img = (restored_img * 255.0).round().astype(np.uint8)  # Convert to uint8
    return Image.fromarray(restored_img)


title = "Fast Baselines for Real-World Low-Light Enhancement 🌠⚡🎆"

description = '''
## [Github Repository](https://github.com/cidautai/NAFourNet)

[Juan Carlos Benito](https://github.com/juaben)

Fundación Cidaut

> **Disclaimer:** please remember this is not a product, thus, you will notice some limitations.
**This demo expects an image with some degradations.
If the checkbox is selected, the program will load the UHD-LL weights; if not, it will load the LOLv2-Real weights.**
Due to GPU memory limitations, high-resolution images (2K, 4K) may fail to process.
'''

# Each example pairs an image with the checkbox value for its dataset: the *_UHD_LL
# images use the UHD-LL weights; the lowXXXXX images presumably come from LOLv2-Real.
examples = [
    ['images/low00772.png', False],
    ['images/low00723.png', False],
    ['images/425_UHD_LL.JPG', True],
    ['images/1778_UHD_LL.JPG', True],
    ['images/1791_UHD_LL.JPG', True],
]

css = """
.image-frame img, .image-container img {
    width: auto;
    height: auto;
    max-width: none;
}
"""

demo = gr.Interface(
    fn=process_img,
    inputs=[
        gr.Image(type='pil', label='input'),
        gr.Checkbox(label='UHD-LL model'),
    ],
    outputs=[gr.Image(type='pil', label='output')],
    title=title,
    description=description,
    examples=examples,
    css=css,
)

if __name__ == '__main__':
    demo.launch()