|
from __future__ import division |
|
|
|
import numpy as np |
|
import cv2 |
|
import random |
|
import torch |
|
import glob |
|
import os |
|
from random import choices |
|
from scipy.stats import poisson |
|
|
|
def Rawread(path, low=0):
    """Load a raw frame, choosing the reader from the file extension.

    Args:
        path: file name; its suffix selects the reader
            ('.raw' -> read_img, '.npy' -> read_npy, '.png' -> read_png).
        low: lower clipping bound, forwarded unchanged to the reader.

    Returns:
        Whatever the selected reader returns, or None when the suffix is
        not one of the three listed above.
    """
    if path.endswith('.raw'):
        reader = read_img
    elif path.endswith('.npy'):
        reader = read_npy
    elif path.endswith('.png'):
        reader = read_png
    else:
        # Unknown extension: nothing is read and None is returned.
        return None
    return reader(path, low)
|
|
|
def read_img(path, low):
    """Read a headerless 16-bit raw dump and return it packed and clipped.

    The file content is interpreted as uint16 samples and reshaped to a
    3000 x 4000 mosaic. 64 is subtracted from every sample, the mosaic is
    packed by rggb_raw, and the result is clipped to [low, 959].

    Args:
        path: path of the raw file.
        low: lower clipping bound.

    Returns:
        The packed array returned by rggb_raw, clipped to [low, 959].
    """
    height, width = 3000, 4000

    mosaic = np.fromfile(path, np.uint16).reshape((height, width))
    # 64 is presumably the sensor black level -- TODO confirm.
    shifted = mosaic.astype(np.float32) - 64
    return np.clip(rggb_raw(shifted), low, 959)
|
|
|
|
|
def read_npy(path, low):
    """Load a frame stored as a NumPy ``.npy`` file.

    Arrays whose first axis has length 4 are returned multiplied by 959,
    without clipping (``low`` is ignored on that path). Any other array has
    64 subtracted, is packed by rggb_raw and is clipped to [low, 959].

    Args:
        path: path of the ``.npy`` file.
        low: lower clipping bound for the unpacked-mosaic path.

    Returns:
        The scaled or packed-and-clipped array.
    """
    data = np.load(path)

    already_packed = data.shape[0] == 4
    if already_packed:
        # presumably a packed frame normalised to [0, 1] -- verify against the writer
        return data * 959

    packed = rggb_raw(data.astype(np.float32) - 64)
    return np.clip(packed, low, 959)
|
|
|
def read_rawpng(path, metadata):
    """Read a Bayer mosaic stored as a PNG and return it packed in [0, 1].

    Samples are mapped with ``(x - 256) / (4095 - 256)``, clipped to [0, 1],
    packed into four planes by bayer2raw and clipped to [0, 1] once more.

    Args:
        path: path of the PNG file (converted with ``str`` before use).
        metadata: mapping forwarded to bayer2raw, which reads
            ``metadata['cfa_pattern']``.

    Returns:
        The packed array with values in [0, 1].

    Raises:
        IOError: if OpenCV could not read or decode the file.
    """
    mosaic = cv2.imread(str(path), cv2.IMREAD_UNCHANGED)
    if mosaic is None:
        # cv2.imread reports failure by returning None instead of raising;
        # fail here with the offending path rather than on the next line.
        raise IOError('Could not read image: {}'.format(path))

    # 256 / 4095 are presumably the black and white levels -- TODO confirm.
    mosaic = ((mosaic.astype(np.float32) - 256.) / (4095. - 256.)).clip(0, 1)

    packed = bayer2raw(mosaic, metadata)
    return np.clip(packed, 0., 1.)
|
|
|
def read_png(path, low):
    """Read a frame stored as a PNG and return it packed and clipped.

    The image is loaded with ``cv2.IMREAD_UNCHANGED``. If its first axis
    has length 4 it is returned multiplied by 959 without clipping.
    Otherwise 256 is subtracted, the mosaic is packed by rggb_raw and the
    result is clipped to [low, 4095].

    Args:
        path: path of the PNG file (converted with ``str`` before use).
        low: lower clipping bound.

    Returns:
        The scaled or packed-and-clipped array.

    Raises:
        IOError: if OpenCV could not read or decode the file.
    """
    raw = cv2.imread(str(path), cv2.IMREAD_UNCHANGED)
    if raw is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise IOError('Could not read image: {}'.format(path))

    if raw.shape[0] == 4:
        return raw * 959

    # 256 is presumably the black level of this format -- TODO confirm.
    packed = rggb_raw(raw.astype(np.float32) - 256)
    return np.clip(packed, low, 4095)
|
|
|
def random_crop(frames_0, frames_1=None, crop_size=128):
    """Cut the same random square window out of one or two frame stacks.

    Args:
        frames_0: 4-D array; the window is taken over its last two axes.
        frames_1: optional second array cropped at the same location.
        crop_size: side length of the square window.

    Returns:
        The crop of ``frames_0``; when ``frames_1`` is given, both crops
        concatenated along axis 0 with the ``frames_0`` crop first.
    """
    _, _, height, width = frames_0.shape

    # The column offset is drawn before the row offset.
    left = random.randint(0, width - crop_size)
    top = random.randint(0, height - crop_size)
    window = (Ellipsis,
              slice(top, top + crop_size),
              slice(left, left + crop_size))

    cropped = frames_0[window]
    if frames_1 is None:
        return cropped
    return np.concatenate([cropped, frames_1[window]], axis=0)
|
|
|
def rggb_raw(raw):
    """Pack a 2-D mosaic into four half-resolution planes.

    Plane k (k = 0..3) holds the samples whose (row, column) parity is
    (0, 0), (0, 1), (1, 0) and (1, 1) respectively.

    Args:
        raw: 2-D array of shape (H, W).

    Returns:
        Array with the four planes stacked along a new leading axis.
    """
    height, width = raw.shape  # unpacking also rejects inputs that are not 2-D
    phases = ((0, 0), (0, 1), (1, 0), (1, 1))
    planes = [raw[dy:height:2, dx:width:2] for dy, dx in phases]
    return np.stack(planes, axis=0)
|
|
|
def bayer2raw(raw, metadata):
    """Pack a 2-D Bayer mosaic into four planes, ordered by CFA pattern.

    When ``metadata['cfa_pattern'][0] == 0`` the planes are taken at
    (row, column) parities (0, 0), (0, 1), (1, 0), (1, 1). For any other
    value the first and last planes trade places:
    (1, 1), (0, 1), (1, 0), (0, 0).

    Args:
        raw: 2-D array of shape (H, W).
        metadata: mapping with a ``'cfa_pattern'`` sequence.

    Returns:
        Array with the four planes stacked along a new leading axis.
    """
    height, width = raw.shape
    if metadata['cfa_pattern'][0] == 0:
        phases = ((0, 0), (0, 1), (1, 0), (1, 1))
    else:
        phases = ((1, 1), (0, 1), (1, 0), (0, 0))
    planes = [raw[dy:height:2, dx:width:2] for dy, dx in phases]
    return np.stack(planes, axis=0)
|
|
|
def raw_rggb(raws):
    """Interleave four packed planes back into a single uint16 mosaic.

    Plane k of ``raws`` is written to the positions whose (row, column)
    parity is (0, 0), (0, 1), (1, 0), (1, 1) for k = 0..3.

    Args:
        raws: 3-D array of shape (C, H, W); planes 0..3 are used.

    Returns:
        uint16 array of shape (2 * H, 2 * W).
    """
    _, height, width = raws.shape
    mosaic = np.zeros((2 * height, 2 * width), dtype=np.uint16)
    phases = ((0, 0), (0, 1), (1, 0), (1, 1))
    for plane, (dy, dx) in enumerate(phases):
        mosaic[dy::2, dx::2] = raws[plane:plane + 1]
    return mosaic
|
|
|
|
|
def raw_rggb_float32(raws):
    """Interleave four packed planes back into a single float32 mosaic.

    Plane k of ``raws`` is written to the positions whose (row, column)
    parity is (0, 0), (0, 1), (1, 0), (1, 1) for k = 0..3.

    Args:
        raws: 3-D array of shape (C, H, W); planes 0..3 are used.

    Returns:
        float32 array of shape (2 * H, 2 * W).
    """
    _, height, width = raws.shape
    mosaic = np.zeros((2 * height, 2 * width), dtype=np.float32)
    phases = ((0, 0), (0, 1), (1, 0), (1, 1))
    for plane, (dy, dx) in enumerate(phases):
        mosaic[dy::2, dx::2] = raws[plane:plane + 1]
    return mosaic
|
|
|
|
|
def depack_rggb_raws(raws):
    """Interleave a batch of four-plane packed frames into single-channel mosaics.

    Channel k of ``raws`` is written to the positions whose (row, column)
    parity is (0, 0), (0, 1), (1, 0), (1, 1) for k = 0..3.

    Args:
        raws: 4-D tensor of shape (N, C, H, W); channels 0..3 are used.

    Returns:
        Tensor of shape (N, 1, 2 * H, 2 * W). It is allocated with
        ``torch.zeros`` defaults, so its dtype and device do not follow the input.
    """
    batch, _, height, width = raws.shape
    mosaic = torch.zeros((batch, 1, 2 * height, 2 * width))
    phases = ((0, 0), (0, 1), (1, 0), (1, 1))
    for plane, (dy, dx) in enumerate(phases):
        mosaic[:, :, dy::2, dx::2] = raws[:, plane:plane + 1, :, :]
    return mosaic
|
|
|
|
|
|
|
|
|
IMAGETYPES = ('*.npy', '*.raw',)


def get_imagenames(seq_dir, pattern=None):
    """Get the ordered list of frame files found in a directory.

    Files matching any glob in IMAGETYPES are collected, optionally
    filtered by a substring of the file name, and sorted by the integer
    formed from all digits appearing in each path.

    This function used to be defined twice in a row with identical bodies;
    the two copies are merged into this single definition.

    Args:
        seq_dir: directory to search (not recursive).
        pattern: if not None, keep only files whose base name contains it.

    Returns:
        List of paths in ascending numeric order. Paths that contain no
        digit at all sort first instead of raising ValueError.
    """
    files = []
    for typ in IMAGETYPES:
        files.extend(glob.glob(os.path.join(seq_dir, typ)))

    if pattern is not None:
        files = [f for f in files if pattern in os.path.basename(f)]

    def _numeric_key(path):
        # int('') raises ValueError, so digit-free paths get a sentinel
        # that sorts before every real frame number.
        digits = ''.join(filter(str.isdigit, path))
        return int(digits) if digits else -1

    files.sort(key=_numeric_key)
    return files
|
|
|
def open_sequence(seq_dir, gray_mode, expand_if_needed=False, max_num_fr=100):
    r""" Opens the clean and noisy frames of a sequence as packed stacks

    Args:
        seq_dir: string, directory searched by get_imagenames
        gray_mode: forwarded to open_image
        expand_if_needed: forwarded to open_image
        max_num_fr: maximum number of frames to load
    Returns:
        seq_raw: clean frames, each packed by rggb_raw, stacked along axis 0
        seq_raw_noise: noisy counterparts, stacked the same way
        expanded_h: flag reported by open_image for the last frame loaded
        expanded_w: flag reported by open_image for the last frame loaded
    """
    frame_paths = get_imagenames(seq_dir)[0:max_num_fr]

    print("\tOpen sequence in folder: ", seq_dir)
    clean_frames, noisy_frames = [], []
    for fpath in frame_paths:
        clean, noisy, expanded_h, expanded_w = open_image(
            fpath,
            gray_mode=gray_mode,
            expand_if_needed=expand_if_needed,
            expand_axis0=False)

        clean_frames.append(rggb_raw(clean))
        noisy_frames.append(rggb_raw(noisy))

    seq_raw = np.stack(clean_frames, axis=0)
    seq_raw_noise = np.stack(noisy_frames, axis=0)
    return seq_raw, seq_raw_noise, expanded_h, expanded_w
|
|
|
def open_image(fpath, gray_mode, expand_if_needed=False, expand_axis0=True, normalize_data=True):
    r""" Opens a clean raw frame together with its noisy counterpart

    Both files are read as 4000 x 3000 uint16 dumps, shifted by -64 and
    clipped to [0, 959]. The noisy file name is derived from ``fpath`` by
    replacing 'onlyraw_test_clean_raw' with 'onlyraw_test_noise_raw'.

    Args:
        fpath: string, path of the clean raw file
        gray_mode: accepted but not used by this implementation
        expand_if_needed: accepted but not used by this implementation
        expand_axis0: accepted but not used by this implementation
        normalize_data: if True, both frames are passed through normalize
    Returns:
        raw_img: clean frame of shape (3000, 4000)
        raw_img_noise: noisy frame of shape (3000, 4000)
        expanded_h: always False, no expansion is performed
        expanded_w: always False, no expansion is performed
    """
    width, height = 4000, 3000

    def _load_mosaic(path):
        # uint16 dump -> (height, width) float32, shifted by 64, clipped to [0, 959]
        samples = np.fromfile(path, dtype=np.uint16, count=width * height)
        mosaic = samples.reshape((height, width)).astype(np.float32) - 64
        return np.clip(mosaic, 0, 959)

    raw_img = _load_mosaic(fpath)
    raw_img_noise = _load_mosaic(
        fpath.replace('onlyraw_test_clean_raw', 'onlyraw_test_noise_raw'))

    if normalize_data:
        raw_img = normalize(raw_img)
        raw_img_noise = normalize(raw_img_noise)

    return raw_img, raw_img_noise, False, False
|
|
|
|
|
def normalize(data):
    r"""Divides the input by 959 and returns the result as float32

    Args:
        data: numpy array (or scalar) to scale
    Returns:
        ``np.float32(data / 959)``
    """
    full_scale = 959
    scaled = data / full_scale
    return np.float32(scaled)
|
|
|
|
|
def augment_cuda(batches, args, spynet=None):
    """Randomly flip/rotate a batch, then split it into clean and noisy tensors.

    The whole batch gets one random horizontal flip, one random vertical
    flip (each with probability 0.5) and a random number (0..3) of quarter
    turns over its last two axes.

    With ``args.pair`` set, the first ``args.frame`` entries along axis 1
    become the noisy input and entry ``args.frame`` becomes the clean
    target, both divided by 959. Otherwise Noise_simulation produces the
    pair, and unless ``args.consistent_loss`` is set only the entry at
    ``args.frame // 2`` along axis 1 is kept as the clean target.

    Args:
        batches: tensor whose last two axes are spatial.
        args: namespace providing ``pair``, ``frame`` and, when ``pair`` is
            false, ``consistent_loss`` plus whatever Noise_simulation reads.
        spynet: accepted but not used.

    Returns:
        (clean, noise, None), where the 5-D noise tensor has axes 1 and 2
        merged into one.
    """
    # Random draws happen in this order: horizontal flip, vertical flip,
    # number of quarter turns.
    flip_w = random.random() < 0.5
    flip_h = random.random() < 0.5
    quarter_turns = np.random.randint(0, 4)

    batches_aug = batches
    if flip_w:
        batches_aug = batches_aug.flip(-1)
    if flip_h:
        batches_aug = batches_aug.flip(-2)
    batches_aug = torch.rot90(batches_aug, k=quarter_turns, dims=[-2, -1])

    if args.pair:
        noise = batches_aug[:, :args.frame, ...] / 959
        clean = batches_aug[:, args.frame, ...] / 959
    else:
        clean, noise = Noise_simulation(batches_aug, args)
        if not args.consistent_loss:
            clean = clean[:, args.frame // 2, ...]

    # NOTE(review): the original indentation was lost; merging the frame and
    # channel axes is assumed to apply to both branches -- confirm.
    B, F, C, H, W = noise.shape
    noise = noise.reshape(B, F * C, H, W)

    return clean, noise, None
|
|
|
|
|
def Noise_simulation(batches_aug, args):
    """Darken a clean batch and synthesise its noisy counterpart.

    The batch is divided by 959 and clamped to [0, 1]. With
    ``args.need_Scaling`` each sample is divided by ``mean / rand_avg`` and
    clamped again, where ``rand_avg`` is drawn per sample according to
    ``args.sample_gain``:

    * ``'type1'``: uniform in [0.001, 0.121)
    * ``'type2'``: drawn by Gain_Sampler

    Noise parameters come from random_noise_levels_nightimaging and the
    noise itself from add_noise. The noisy result is clamped to [-0.1, 1].

    Args:
        batches_aug: 5-D tensor (the per-sample mean is taken over axes 1-4).
        args: namespace providing ``need_Scaling``, ``sample_gain``,
            ``local_rank`` and whatever add_noise reads.

    Returns:
        (clean, noisy) as float tensors.

    Raises:
        ValueError: if scaling is requested with an unknown
            ``args.sample_gain`` (previously an UnboundLocalError).
    """
    batches_aug = torch.clamp(batches_aug / 959, 0, 1)
    B = batches_aug.shape[0]

    if args.need_Scaling:
        if args.sample_gain == 'type1':
            rand_avg = (torch.rand(B) * 0.12 + 0.001).cuda(args.local_rank)
        elif args.sample_gain == 'type2':
            rand_avg = Gain_Sampler(B).cuda(args.local_rank)
        else:
            raise ValueError(
                "Unsupported sample_gain {!r}; expected 'type1' or 'type2'".format(
                    args.sample_gain))

        batch_aug_mean = batches_aug.mean(dim=(1, 2, 3, 4))
        # (B,) -> (B, 1, 1, 1, 1) so the factor broadcasts over each sample
        coef = (batch_aug_mean / rand_avg)[:, None, None, None, None]
        batch_aug_dark = torch.clamp(batches_aug / coef, 0, 1)
    else:
        batch_aug_dark = batches_aug

    a, b, _, dgain = random_noise_levels_nightimaging(B, args)
    batch_aug_dark, batch_aug_dark_noise = add_noise(
        args,
        batch_aug_dark,
        a.cuda(args.local_rank),
        b.cuda(args.local_rank),
        dgain.cuda(args.local_rank))

    batch_aug_dark_noise = torch.clamp(batch_aug_dark_noise, -0.1, 1)

    return batch_aug_dark.float(), batch_aug_dark_noise.float()
|
|
|
def random_noise_levels_nightimaging(B, args):
    """Draw noise parameters for B samples from a stored profile table.

    B row indices are drawn uniformly from 0..124, then columns 0 and 1 of
    the table stored in the ``.npy`` file are looked up at those rows.

    Args:
        B: number of samples.
        args: accepted but not used.

    Returns:
        (a, b, 1, ones) where ``a`` and ``b`` are the looked-up columns and
        ``ones`` is a one-element tensor holding 1.
    """
    # NOTE(review): machine-specific absolute path, reloaded on every call.
    profile_path = '/data1/chengqihua/02_code/03_night_photogrphy/nightimage_v1/dataloader/json_all_2nd.npy'

    # Indices are drawn before the table is loaded.
    rows = torch.FloatTensor(B).uniform_(0, 125).int().long()
    table = torch.from_numpy(np.load(profile_path))

    return table[rows, 0], table[rows, 1], 1, 1 * torch.ones(1)
|
|
|
def random_noise_levels(B, args):
    """Draw a random gain per sample and derive its noise parameters.

    A gain ``g`` is drawn uniformly from [args.min_gain, args.max_gain).
    It is split into ``again = min(g, 16)`` and ``dgain = max(g / 16, 1)``.
    ``a`` is linear and ``b`` quadratic in ``again``.

    Args:
        B: number of samples.
        args: namespace providing ``min_gain`` and ``max_gain``.

    Returns:
        (a, b, again, dgain), each a tensor of shape (B,).
    """
    ak1 = 0.05244803
    ak2 = 0.01498041
    bk1 = 0.00648923
    bk2 = 0.05899386
    bk3 = 0.21520193

    g = torch.FloatTensor(B).uniform_(args.min_gain, args.max_gain)

    # Gains above 16 saturate the first factor; the excess goes to the second.
    again = torch.clamp(g, max=16)
    dgain = torch.clamp(g / 16, min=1)

    a = ak1 * again + ak2
    b = bk1 * again * again + bk2 * again + bk3

    return a, b, again, dgain
|
|
|
def add_noise(args, image, a, b, dgain):
    """Add signal-dependent Poisson noise and Gaussian noise to a batch.

    The image is first divided by ``dgain``. Poisson noise is applied as
    ``poisson(image / a) * a`` and Gaussian noise with variance ``b`` is
    added. With ``args.usedgain`` both the noisy and the clean image are
    multiplied by ``dgain`` again.

    The Gaussian noise is generated on ``image.device`` rather than on the
    CPU and then copied to ``cuda(args.local_rank)``. This avoids the
    transfer and lets the function run on CPU tensors; random values for a
    given seed therefore differ from the previous implementation.

    Args:
        args: namespace providing ``usedgain``.
        image: 5-D tensor (B, F, C, H, W).
        a: 1-D tensor of per-sample Poisson scales.
        b: 1-D tensor of per-sample Gaussian variances.
        dgain: 1-D tensor of per-sample gains (or a single element).

    Returns:
        (image, noiseimg): the clean image after gain handling, and its
        noisy counterpart.
    """
    def _per_sample(values):
        # (N,) -> (N, 1, 1, 1, 1) so the parameter broadcasts over F, C, H, W
        return values[:, None, None, None, None]

    dgain = _per_sample(dgain)
    a = _per_sample(a)
    b = _per_sample(b)

    B, F, C, H, W = image.size()

    image = image / dgain

    poisson_noisy_img = torch.poisson(image / a) * a
    gaussian_noise = torch.sqrt(b) * torch.randn(B, F, C, H, W, device=image.device)
    noiseimg = poisson_noisy_img + gaussian_noise

    # NOTE(review): the original indentation was lost; restoring the clean
    # image is assumed to belong to the usedgain branch -- confirm.
    if args.usedgain:
        noiseimg = noiseimg * dgain
        image = image * dgain
    return image, noiseimg
|
|
|
|
|
|
|
def normalize_augment(datain):
    '''Normalizes and augments an input patch of dim [N, num_frames, C, H, W].

    The patch is divided by 255 and one randomly chosen augmentation is
    applied to all frames at once: identity (weight 32), or one of seven
    flip/rotation combinations or a per-sample constant offset with
    std 5/255 (weight 12 each).

    Args:
        datain: tensor of shape [N, num_frames, C, H, W].

    Returns:
        (img_train, img_train): the same augmented tensor twice, shaped
        [N, num_frames, C, H', W']. A quarter-turn swaps the spatial axes,
        so (H', W') is (W, H) for non-square patches in that case.
    '''
    def transform(sample):
        def add_csnt(x):
            # One constant offset per sample, broadcast over every other axis.
            offset = torch.normal(mean=torch.zeros(x.size()[0], 1, 1, 1),
                                  std=(5 / 255.))
            return x + offset.expand_as(x).to(x.device)

        def rotated(k, flip):
            # k quarter-turns over the spatial axes, optionally followed by
            # a flip along axis 2.
            def apply(x):
                out = torch.rot90(x, k=k, dims=[2, 3]) if k else x
                return torch.flip(out, dims=[2]) if flip else out
            return apply

        # Same order and weights as before: do_nothing, flipud, rot90,
        # rot90_flipud, rot180, rot180_flipud, rot270, rot270_flipud, add_csnt.
        aug_list = [rotated(k, flip) for k in range(4) for flip in (False, True)]
        aug_list.append(add_csnt)
        w_aug = [32, 12, 12, 12, 12, 12, 12, 12, 12]
        transf = choices(aug_list, w_aug)

        return transf[0](sample)

    N, F, C, H, W = datain.shape
    img_train = datain.view(N, -1, H, W) / 255.

    img_train = transform(img_train)
    # Take the spatial size from the augmented tensor: reusing the input's
    # (H, W) fails for non-square patches after a 90/270 degree rotation.
    img_train = img_train.reshape(N, F, C, *img_train.shape[-2:])

    return img_train, img_train
|
|
|
def Gain_Sampler(B):
    """Draw B integer gains from three weighted ranges.

    Each sample first picks a level with probabilities 0.7 / 0.2 / 0.1
    ('low' / 'mid' / 'high'), then draws an integer uniformly from that
    level's half-open range: [5, 35), [35, 60) or [60, 100).

    Args:
        B: number of gains to draw.

    Returns:
        Float tensor of shape (B,).
    """
    gain_dict = {
        'low': [5, 35],
        'mid': [35, 60],
        'high': [60, 100],
    }
    level = ['low', 'mid', 'high']

    # The weights must be passed as p=: the third positional parameter of
    # np.random.choice is `replace`, so passing them positionally left the
    # level choice uniform.
    sampled = np.random.choice(level, B, p=[0.7, 0.2, 0.1])

    gains = [
        torch.randint(gain_dict[name][0], gain_dict[name][1], (1,)).item()
        for name in sampled
    ]
    return torch.Tensor(gains)
|
|
|
def path_replace(path, args):
    """Apply a list of substring replacements to a path, in order.

    The i-th entry of ``args.replace_left`` is replaced by the i-th entry
    of ``args.replace_right``; later replacements see the result of
    earlier ones.

    Args:
        path: string to rewrite.
        args: namespace providing ``replace_left`` and ``replace_right``.

    Returns:
        The rewritten string.
    """
    for idx, old in enumerate(args.replace_left):
        path = path.replace(old, args.replace_right[idx])
    return path