Intro
This repository documents my efforts to train a Cascaded Gaze image-denoising network that is usable on real-world images.
denoise_util.py includes all definitions required to use Cascaded Gaze networks with PyTorch.
Models
v1
- ~132M parameters, trained on 256×256 RGB patches to remove intermediate-level JPEG and WebP compression artefacts. It was trained on about 700k samples (photographs only) in bf16 precision. It can also remove ISO-like noise and Gaussian noise.
- I recommend inputting float tensors of shape [B, 3, 256, 256] with values scaled to the range 0 to 1.
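A batch can be checked against that recommendation before it reaches the model. The helper below is a sketch: check_input is a hypothetical name, not part of denoise_util.py, and it only enforces the shape, dtype and value range stated above.

```python
import torch


def check_input(batch: torch.Tensor, size: int = 256) -> None:
    """Raise ValueError unless `batch` is a float [B, 3, size, size] tensor in [0, 1].

    Hypothetical helper for illustration; not part of denoise_util.py.
    """
    if batch.dim() != 4:
        raise ValueError(f"expected 4 dims [B, C, H, W], got {tuple(batch.shape)}")
    if batch.shape[1] != 3:
        raise ValueError(f"expected 3 (RGB) channels, got {batch.shape[1]}")
    if tuple(batch.shape[-2:]) != (size, size):
        raise ValueError(f"expected {size}x{size} patches, got {tuple(batch.shape[-2:])}")
    if not batch.is_floating_point():
        raise ValueError(f"expected a float tensor, got {batch.dtype}")
    if batch.min() < 0.0 or batch.max() > 1.0:
        raise ValueError("values must be scaled to the range 0 to 1")


# A random float batch in [0, 1) passes; raw 8-bit pixel values would not.
check_input(torch.rand(4, 3, 256, 256))
```

Images loaded with torchvision.transforms.functional.to_tensor are already floats in 0 to 1, so in practice the shape check is the one most likely to fail.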
Loading v1
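import torch
from safetensors.torch import load_file
from denoise_util import CascadedGaze
# Use the GPU when available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
# Architecture hyperparameters for v1; these must match the checkpoint
img_channel = 3
width = 60
enc_blks = [2, 2, 4, 6]
middle_blk_num = 12
dec_blks = [2, 2, 2, 2]
GCE_CONVS_nums = [3, 3, 2, 2]
model = CascadedGaze(img_channel=img_channel, width=width, middle_blk_num=middle_blk_num,
                     enc_blk_nums=enc_blks, dec_blk_nums=dec_blks, GCE_CONVS_nums=GCE_CONVS_nums)
# Load the v1 weights, move the model to the device and switch to inference mode
state_dict = load_file("models/v1.safetensors")
model.load_state_dict(state_dict)
model = model.to(device)
model.requires_grad_(False)  # no gradients are needed for inference
model.eval()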
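As a sanity check after loading, the parameter count can be compared with the ~132M figure given for v1. count_parameters is a hypothetical helper name; it works for any torch.nn.Module, and parameters frozen with requires_grad_(False) are still counted.

```python
from torch import nn


def count_parameters(module: nn.Module) -> int:
    """Hypothetical helper: total number of elements across all parameters of `module`."""
    return sum(p.numel() for p in module.parameters())


# With the v1 model loaded as above, compare the result with the ~132M stated for v1:
# print(f"{count_parameters(model) / 1e6:.1f}M parameters")
```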
Usage
- Uses https://github.com/ProGamerGov/blended-tiling to split images of arbitrary size into 256×256 tiles and then blend the denoised tiles back into a full image.
- You'll need to make amendments to keep each batch small enough for your device's memory; as written, all tiles are processed in a single batch.
- It presumes the model has already been loaded with the code above.
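One way to make those amendments is to feed the tiles to the model in fixed-size chunks rather than all at once. This is a sketch: run_in_chunks is a hypothetical helper, and there is no recommended value for max_batch here; pick the largest value your device's memory allows.

```python
import torch
from torch import nn


@torch.no_grad()  # extra safeguard; the loaded model already has requires_grad_(False)
def run_in_chunks(model: nn.Module, tiles: torch.Tensor, max_batch: int, device: str) -> torch.Tensor:
    """Hypothetical helper: run `model` over `tiles` ([N, C, H, W]) at most `max_batch` at a time.

    Each chunk is moved to `device` for the forward pass and its output is copied
    back to the CPU, so the device only holds one chunk's tiles and activations at once.
    """
    if max_batch < 1:
        raise ValueError("max_batch must be at least 1")
    outputs = []
    for chunk in torch.split(tiles, max_batch, dim=0):  # last chunk may be smaller
        outputs.append(model(chunk.to(device)).cpu())
    return torch.cat(outputs, dim=0)  # same tile order as the input
```

In the usage code, it would replace tiles = tiles.to(device) and result = model(tiles).cpu() with result = run_in_chunks(model, tiles, max_batch=8, device=device), where 8 is only a placeholder. The output keeps the original tile order, so it can go straight into tiling_module.rebuild_with_masks.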
import torch
from PIL import Image
import torchvision
from blended_tiling import TilingModule
def toimg(tensor):
    # Convert a [3, H, W] float tensor in 0-1 to a PIL image
    tensor = torch.clamp(tensor, 0.0, 1.0)
    tensor = (tensor * 255).round()  # round to the nearest 8-bit level instead of truncating
    tensor = tensor.byte()
    return torchvision.transforms.functional.to_pil_image(tensor)
# NB: .convert("RGB") discards any alpha channel, so RGBA inputs need extra handling if transparency matters.
pil_image = Image.open("input.jpg").convert("RGB")
tensor = torchvision.transforms.functional.to_tensor(pil_image)  # [3, H, W], floats in 0-1
tensor = torch.unsqueeze(tensor, 0)  # add a batch dimension: [1, 3, H, W]
tiling_module = TilingModule(
    tile_size=[256, 256],
    tile_overlap=[0.1, 0.1],  # you can configure this to taste
    base_size=list(tensor.shape[-2:]),  # [height, width]; PIL's .size is (width, height)
)
tiles = tiling_module.split_into_tiles(tensor)  # [num_tiles, 3, 256, 256]
tiles = tiles.to(device)
result = model(tiles).cpu()  # processes every tile in a single batch
result = tiling_module.rebuild_with_masks(result).squeeze()
pil_result = toimg(result)
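Casting floats to uint8 with .byte() alone truncates toward zero, so each pixel can come out up to one level darker than the nearest 8-bit value. A minimal sketch of a rounding conversion, assuming the model output lies in 0 to 1; to_uint8 is a hypothetical helper name:

```python
import torch


def to_uint8(tensor: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: map floats in [0, 1] to uint8 in [0, 255], rounding to the nearest level."""
    return (tensor.clamp(0.0, 1.0) * 255).round().to(torch.uint8)


x = torch.tensor([100.9 / 255])
print((x * 255).byte().item())  # 100 (truncated)
print(to_uint8(x).item())       # 101 (rounded)
```

The resulting uint8 tensor can be passed to torchvision.transforms.functional.to_pil_image in the same way as in toimg.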