Intro

These are my efforts to train a Cascaded Gaze image denoising network that is usable on real-world images.

denoise_util.py includes all definitions required to use Cascaded Gaze networks with PyTorch.

Models

v1

  • ~132M parameters, trained on 256×256 RGB patches to remove intermediate-level JPEG and WebP compression artefacts. Training used about 700k samples (photographs only) in bf16 precision. The model can also remove ISO-like noise and Gaussian noise.
  • I recommend inputting tensors of shape [B,3,256,256] containing float values scaled to the range 0 to 1.
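As a minimal sketch of getting data into the recommended layout, the hypothetical helper below stacks 8-bit RGB images (assumed to be H×W×3 NumPy `uint8` arrays, e.g. from `np.asarray(pil_image)`) into a float tensor of shape [B,3,H,W] with values from 0 to 1. It does not crop or resize; producing 256×256 tiles is left to the tiling step under Usage.

```python
import numpy as np
import torch


def to_model_input(images):
    """Stack HxWx3 uint8 RGB arrays into a float32 [B, 3, H, W] tensor in [0, 1].

    Hypothetical helper: all images must share the same height and width.
    """
    batch = np.stack(images, axis=0)  # [B, H, W, 3]
    if batch.ndim != 4 or batch.shape[-1] != 3:
        raise ValueError(f"expected a list of HxWx3 arrays, got batch shape {batch.shape}")
    tensor = torch.from_numpy(batch).permute(0, 3, 1, 2)  # [B, 3, H, W]
    # uint8 -> float32, then scale 0..255 down to 0..1
    return tensor.float().div(255.0).contiguous()
```

For example, `to_model_input([np.asarray(img1), np.asarray(img2)])` with two 256×256 RGB crops yields a [2,3,256,256] batch ready for the model.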

Loading v1

from denoise_util import CascadedGaze
from safetensors.torch import load_file

device = "cuda"

img_channel = 3
width = 60
enc_blks = [2, 2, 4, 6]
middle_blk_num = 12
dec_blks = [2, 2, 2, 2]
GCE_CONVS_nums = [3, 3, 2, 2]

model = CascadedGaze(img_channel=img_channel, width=width, middle_blk_num=middle_blk_num,
        enc_blk_nums=enc_blks, dec_blk_nums=dec_blks, GCE_CONVS_nums=GCE_CONVS_nums)

state_dict = load_file("models/v1.safetensors")
model.load_state_dict(state_dict)
model = model.to(device)
model.requires_grad_(False)
model.eval()
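To sanity-check that the configuration above matches the checkpoint's stated size of ~132M parameters, a small helper like this can count parameters. It works on any `torch.nn.Module`; the exact figure for `model` is not stated in this card, so treat "roughly 132M" as the expectation.

```python
import torch.nn as nn


def count_parameters(module: nn.Module) -> int:
    """Return the total number of parameters (trainable or frozen) in a module."""
    # requires_grad_(False) freezes the weights, so count every parameter
    # rather than filtering on p.requires_grad.
    return sum(p.numel() for p in module.parameters())
```

Usage: `print(f"{count_parameters(model) / 1e6:.1f}M parameters")`.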

Usage

  • Uses https://github.com/ProGamerGov/blended-tiling to split images of arbitrary size into 256*256 tiles and then reassemble them.
  • As written, all tiles are passed to the model in a single batch, so you'll need to make amendments to keep batches small enough for your device's memory.
  • Presumes the model has already been loaded with the code above.
import torch
from PIL import Image
import torchvision.transforms.functional as TF
from blended_tiling import TilingModule

def toimg(tensor):
    # Convert a [3, H, W] float tensor in [0, 1] to a PIL image,
    # rounding (rather than truncating) to 8-bit values.
    tensor = torch.clamp(tensor, 0.0, 1.0)
    tensor = (tensor * 255).round().byte()
    return TF.to_pil_image(tensor)

# nb: .convert("RGB") discards any alpha channel, so this isn't sufficient
# if RGBA transparency must be preserved.
pil_image = Image.open("input.jpg").convert("RGB")

# [1, 3, H, W] float tensor with values in [0, 1]
tensor = TF.to_tensor(pil_image)
tensor = torch.unsqueeze(tensor, 0)

tiling_module = TilingModule(
    tile_size=[256, 256],
    tile_overlap=[0.1, 0.1], # you can configure this to taste
    # base_size is (height, width), matching the tensor layout;
    # PIL's .size is (width, height), which breaks non-square images
    base_size=list(tensor.shape[2:]),
)

tiles = tiling_module.split_into_tiles(tensor)
tiles = tiles.to(device)
# all tiles go through the model in one batch here
result = model(tiles).cpu()
result = tiling_module.rebuild_with_masks(result).squeeze()

pil_result = toimg(result)
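One way to make the batching amendment mentioned under Usage, sketched here as an assumption rather than the author's own approach: run the tiles through the model in fixed-size chunks instead of all at once. `denoise_tiles` and its `batch_size` parameter are hypothetical; lower `batch_size` until a chunk fits in your device's memory.

```python
import torch


@torch.no_grad()
def denoise_tiles(model, tiles, device, batch_size=8):
    """Run `model` over `tiles` ([N, C, H, W]) in chunks of `batch_size`.

    Each chunk is moved to `device`, processed, and moved back to the CPU,
    so only one chunk's activations occupy device memory at a time.
    """
    outputs = []
    for start in range(0, tiles.shape[0], batch_size):
        chunk = tiles[start:start + batch_size].to(device)
        outputs.append(model(chunk).cpu())
    # Reassemble in the original tile order for rebuild_with_masks
    return torch.cat(outputs, dim=0)
```

It can stand in for the `tiles = tiles.to(device)` and `result = model(tiles).cpu()` lines, e.g. `result = denoise_tiles(model, tiles, device, batch_size=4)`. The tiles stay on the CPU until each chunk is needed, trading a little transfer overhead for a bounded peak memory footprint.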