NPRC24 / IIR-Lab /dataloader /data_utils.py
Artyom
IIRLab
6721043 verified
from __future__ import division
import numpy as np
import cv2
import random
import torch
import glob
import os
from random import choices
from scipy.stats import poisson
def Rawread(path, low=0):
    """Load a raw frame, choosing the reader from the file-name suffix.

    Args:
        path: file name; suffixes ``.raw``, ``.npy`` and ``.png`` are handled.
        low: lower clipping bound handed to the selected reader.

    Returns:
        The result of the matching reader, or ``None`` when the suffix is not
        one of the three handled here.
    """
    # Readers are wrapped in lambdas so only the matching one is looked up
    # and called.
    readers = (
        ('.raw', lambda: read_img(path, low)),
        ('.npy', lambda: read_npy(path, low)),
        ('.png', lambda: read_png(path, low)),
    )
    for suffix, load in readers:
        if path.endswith(suffix):
            return load()
    return None
def read_img(path, low):
    """Read a headerless uint16 ``.raw`` file as a packed 4-plane frame.

    The file content is interpreted as a 3000 x 4000 (rows x columns) mosaic,
    converted to float32, shifted down by 64, packed with ``rggb_raw`` and
    finally clipped to ``[low, 959]``.

    Args:
        path: file to read.
        low: lower clipping bound.

    Returns:
        The clipped output of ``rggb_raw``.
    """
    width, height = 4000, 3000
    mosaic = np.fromfile(path, np.uint16).reshape((height, width))
    planes = rggb_raw(mosaic.astype(np.float32) - 64)
    return np.clip(planes, low, 959)
def read_npy(path, low):
    """Read a frame stored as ``.npy``.

    If the first axis of the stored array has length 4 the data is returned
    multiplied by 959 without further processing. Otherwise the array is
    converted to float32, shifted down by 64, packed with ``rggb_raw`` and
    clipped to ``[low, 959]``.

    Args:
        path: ``.npy`` file to load.
        low: lower clipping bound (only used on the second path).
    """
    data = np.load(path)
    if data.shape[0] == 4:
        # NOTE(review): presumably already packed and normalised to [0, 1].
        return data * 959
    shifted = data.astype(np.float32) - 64
    return np.clip(rggb_raw(shifted), low, 959)
def read_rawpng(path, metadata):
    """Read a Bayer PNG and return it packed and normalised to ``[0, 1]``.

    The image is loaded unchanged with OpenCV, mapped linearly so that 256
    becomes 0 and 4095 becomes 1, clipped, packed into four planes by
    ``bayer2raw`` (which uses ``metadata``) and clipped once more.

    Args:
        path: image file; converted with ``str()`` before loading.
        metadata: passed through to ``bayer2raw``.
    """
    lower, upper = 256., 4095.
    mosaic = cv2.imread(str(path), cv2.IMREAD_UNCHANGED)
    scaled = (mosaic.astype(np.float32) - lower) / (upper - lower)
    packed = bayer2raw(scaled.clip(0, 1), metadata)
    return np.clip(packed, 0., 1.)
def read_png(path, low):
    """Read a frame stored as PNG.

    The image is loaded unchanged with OpenCV. If its first axis has length 4
    it is returned multiplied by 959. Otherwise it is converted to float32,
    shifted down by 256, packed with ``rggb_raw`` and clipped to
    ``[low, 4095]``.

    Args:
        path: image file; converted with ``str()`` before loading.
        low: lower clipping bound (only used on the second path).
    """
    image = cv2.imread(str(path), cv2.IMREAD_UNCHANGED)
    if image.shape[0] == 4:
        return image * 959
    shifted = image.astype(np.float32) - 256
    return np.clip(rggb_raw(shifted), low, 4095)
def random_crop(frames_0, frames_1=None, crop_size=128):
    """Cut the same random square window out of one or two frame stacks.

    Args:
        frames_0: 4-D array; the window is taken on its last two axes.
        frames_1: optional second array cropped with the same window.
        crop_size: side length of the square window.

    Returns:
        The crop of ``frames_0``; when ``frames_1`` is given, the two crops
        concatenated along axis 0.
    """
    _, _, height, width = frames_0.shape
    # The horizontal offset is drawn first, then the vertical one.
    left = random.randint(0, width - crop_size)
    top = random.randint(0, height - crop_size)
    window = (Ellipsis,
              slice(top, top + crop_size),
              slice(left, left + crop_size))
    first = frames_0[window]
    if frames_1 is None:
        return first
    return np.concatenate([first, frames_1[window]], axis=0)
def rggb_raw(raw):
    """Pack a 2-D RGGB Bayer mosaic into four half-resolution planes.

    Args:
        raw: 2-D array of shape (H, W).

    Returns:
        Array of shape (4, H/2, W/2). Plane order follows the position inside
        each 2x2 cell: top-left, top-right, bottom-left, bottom-right.
    """
    _rows, _cols = raw.shape  # rejects anything that is not 2-D
    cell_offsets = ((0, 0), (0, 1), (1, 0), (1, 1))
    return np.stack([raw[dy::2, dx::2] for dy, dx in cell_offsets], axis=0)
def bayer2raw(raw, metadata):
    """Pack a 2-D Bayer mosaic into four planes according to its CFA layout.

    Args:
        raw: 2-D array of shape (H, W).
        metadata: mapping whose ``'cfa_pattern'`` entry is indexable; a first
            element equal to 0 is treated as RGGB, anything else as BGGR.

    Returns:
        Array of shape (4, H/2, W/2). For RGGB the planes are taken in cell
        order top-left, top-right, bottom-left, bottom-right; for BGGR the
        first and last planes are swapped.
    """
    _rows, _cols = raw.shape  # rejects anything that is not 2-D
    if metadata['cfa_pattern'][0] == 0:
        # RGGB
        cell_offsets = ((0, 0), (0, 1), (1, 0), (1, 1))
    else:
        # BGGR
        cell_offsets = ((1, 1), (0, 1), (1, 0), (0, 0))
    return np.stack([raw[dy::2, dx::2] for dy, dx in cell_offsets], axis=0)
def raw_rggb(raws):
    """Unpack four planes back into a single uint16 Bayer mosaic.

    Args:
        raws: array of shape (C, H, W); planes 0..3 are written to the
            top-left, top-right, bottom-left and bottom-right position of
            every 2x2 cell.

    Returns:
        uint16 array of shape (2H, 2W).
    """
    _, height, width = raws.shape
    mosaic = np.zeros((height * 2, width * 2), dtype=np.uint16)
    cell_offsets = ((0, 0), (0, 1), (1, 0), (1, 1))
    for plane, (dy, dx) in enumerate(cell_offsets):
        mosaic[dy::2, dx::2] = raws[plane:plane + 1, :, :]
    return mosaic
def raw_rggb_float32(raws):
    """Unpack four planes back into a single float32 Bayer mosaic.

    Args:
        raws: array of shape (C, H, W); planes 0..3 are written to the
            top-left, top-right, bottom-left and bottom-right position of
            every 2x2 cell.

    Returns:
        float32 array of shape (2H, 2W).
    """
    _, height, width = raws.shape
    mosaic = np.zeros((height * 2, width * 2), dtype=np.float32)
    cell_offsets = ((0, 0), (0, 1), (1, 0), (1, 1))
    for plane, (dy, dx) in enumerate(cell_offsets):
        mosaic[dy::2, dx::2] = raws[plane:plane + 1, :, :]
    return mosaic
def depack_rggb_raws(raws):
    """Unpack a batch of 4-plane tensors into single-channel Bayer mosaics.

    Args:
        raws: tensor of shape (N, C, H, W); channels 0..3 are written to the
            top-left, top-right, bottom-left and bottom-right position of
            every 2x2 cell.

    Returns:
        Tensor of shape (N, 1, 2H, 2W). It is created with the defaults of
        ``torch.zeros``, i.e. it does not follow the device or dtype of
        ``raws``.
    """
    batch, _, height, width = raws.shape
    mosaic = torch.zeros((batch, 1, height * 2, width * 2))
    cell_offsets = ((0, 0), (0, 1), (1, 0), (1, 1))
    for plane, (dy, dx) in enumerate(cell_offsets):
        mosaic[:, :, dy::2, dx::2] = raws[:, plane:plane + 1, :, :]
    return mosaic
# Earlier pattern set, kept for reference:
# IMAGETYPES = ('*.bmp', '*.png', '*.jpg', '*.jpeg', '*.tif')
# Glob patterns collected by get_imagenames(). Keep the trailing comma: a
# single pattern written without it would be a plain string and would be
# iterated character by character.
IMAGETYPES = ('*.npy', '*.raw',)


def get_imagenames(seq_dir, pattern=None):
    """Return the frame files of ``seq_dir`` ordered by frame number.

    Files matching any glob in ``IMAGETYPES`` are collected. The frame number
    of a file is the integer formed by the digits of its *file name* (digits
    in directory names are ignored, so a parent folder such as ``seq1`` can
    no longer reorder names with mixed zero padding like ``007`` and ``10``).
    Names without any digit sort first instead of raising ``ValueError``;
    ties are broken by file name so the order is deterministic.

    Args:
        seq_dir: directory to search.
        pattern: optional substring that the file name must contain.

    Returns:
        List of paths (as returned by ``glob``) in ascending frame order.
    """
    files = []
    for typ in IMAGETYPES:
        files.extend(glob.glob(os.path.join(seq_dir, typ)))

    # Keep only the files whose name contains the requested substring.
    if pattern is not None:
        files = [f for f in files if pattern in os.path.basename(f)]

    def frame_order(path):
        name = os.path.basename(path)
        digits = ''.join(filter(str.isdigit, name))
        return (int(digits) if digits else -1, name)

    files.sort(key=frame_order)
    return files
def open_sequence(seq_dir, gray_mode, expand_if_needed=False, max_num_fr=100):
    r""" Opens the clean/noisy raw frames of a folder as two packed stacks
    Args:
        seq_dir: string, folder searched for frame files via get_imagenames()
        gray_mode: forwarded unchanged to open_image()
        expand_if_needed: forwarded unchanged to open_image()
        max_num_fr: maximum number of frames to load
    Returns:
        seq_raw: clean frames, each packed by rggb_raw() and stacked along a
            new leading frame axis.
        seq_raw_noise: the noisy counterparts, same layout.
        expanded_h: flag returned by open_image() for the last frame loaded.
        expanded_w: flag returned by open_image() for the last frame loaded.
    """
    # Get ordered list of filenames
    files = get_imagenames(seq_dir)
    print("\tOpen sequence in folder: ", seq_dir)

    clean_frames, noisy_frames = [], []
    for frame_path in files[:max_num_fr]:
        clean, noisy, expanded_h, expanded_w = open_image(
            frame_path,
            gray_mode=gray_mode,
            expand_if_needed=expand_if_needed,
            expand_axis0=False)
        clean_frames.append(rggb_raw(clean))
        noisy_frames.append(rggb_raw(noisy))

    return (np.stack(clean_frames, axis=0),
            np.stack(noisy_frames, axis=0),
            expanded_h,
            expanded_w)
def open_image(fpath, gray_mode, expand_if_needed=False, expand_axis0=True, normalize_data=True):
    r""" Reads a clean raw frame and its noisy counterpart from disk
    Both files are read as 4000x3000 (width x height) uint16 mosaics, shifted
    down by 64 and clipped to [0, 959] (i.e. 1023 - 64). The noisy file name
    is derived from fpath by replacing 'onlyraw_test_clean_raw' with
    'onlyraw_test_noise_raw'.
    Args:
        fpath: string, path of the clean raw file
        gray_mode: not used by this function
        expand_if_needed: not used by this function
        expand_axis0: not used by this function
        normalize_data: if True, both frames are passed through normalize()
    Returns:
        raw_img: clean frame of shape (3000, 4000), float32
        raw_img_noise: noisy frame, same layout
        expanded_h: always False (no padding is performed)
        expanded_w: always False (no padding is performed)
    """
    width, height = 4000, 3000

    def _load_plane(path):
        # Read exactly one frame worth of samples, then subtract the offset
        # and clip.
        samples = np.fromfile(path, dtype=np.uint16, count=width * height)
        plane = samples.reshape((height, width)).astype(np.float32) - 64
        return np.clip(plane, 0, 959)

    raw_img = _load_plane(fpath)
    noise_fpath = fpath.replace('onlyraw_test_clean_raw', 'onlyraw_test_noise_raw')
    raw_img_noise = _load_plane(noise_fpath)

    # Odd-size handling is disabled, so nothing is ever expanded.
    expanded_h = False
    expanded_w = False

    if normalize_data:
        raw_img = normalize(raw_img)
        raw_img_noise = normalize(raw_img_noise)
    return raw_img, raw_img_noise, expanded_h, expanded_w
def normalize(data):
    r"""Divides the input by 959 and converts the result to float32
    Args:
        data: numeric array (or scalar) to scale; values in [0, 959] end up
            in [0, 1]
    """
    scaled = data / 959
    return np.float32(scaled)
def augment_cuda(batches, args, spynet=None):
    """Randomly flip/rotate a batch and split it into a clean target and a noisy input.

    Args:
        batches: tensor whose last two axes are spatial and whose second axis
            is indexed per frame below (assumed (B, F, C, H, W) - TODO confirm).
        args: options object; this function reads ``pair``, ``frame`` and
            ``consistent_loss`` and forwards ``args`` to Noise_simulation.
        spynet: not used by this function.

    Returns:
        ``(clean, noise, None)``. ``noise`` is reshaped so that its frame and
        channel axes are merged: (B, F*C, H, W).
    """
    def _augment(img, hflip=True, rot=True):
        # Horizontal and vertical flips are each applied with probability 0.5.
        hflip = hflip and random.random() < 0.5
        vflip = rot and random.random() < 0.5
        # rot90 = rot and random.random() < 0.5
        # Number of quarter turns, drawn uniformly from {0, 1, 2, 3}.
        k1 = np.random.randint(0, 4) #0,1,2,3
        if hflip: img = img.flip(-1)
        if vflip: img = img.flip(-2)
        img = torch.rot90(img, k=k1, dims=[-2, -1])
        return img
    batches_aug = _augment(batches)
    if args.pair:
        # Paired data: the first ``args.frame`` frames are the noisy input and
        # the frame at index ``args.frame`` is the clean target; both are
        # divided by 959.
        noise = batches_aug[:,:args.frame,...]/959
        clean = batches_aug[:,args.frame,...]/959 #if args.scene != 'noisedata' else batches_aug[:,args.frame,...]
    else:
        # Unpaired data: noise is synthesised from the clean frames.
        clean, noise = Noise_simulation(batches_aug,args)
        # NOTE(review): the central-frame selection is read here as applying
        # only to the simulated branch (the paired branch already yields a
        # single clean frame) - confirm the intended nesting.
        if not args.consistent_loss:
            clean = clean[:, args.frame // 2, ...]
    B, F, C , H, W = noise.shape
    # Merge the frame and channel axes of the noisy input.
    noise = noise.reshape(B, F*C , H, W )
    return clean, noise, None
def Noise_simulation(batches_aug, args):
    """Darken clean frames (optionally) and synthesise their noisy version.

    The input is divided by 959 and clamped to [0, 1]. When
    ``args.need_Scaling`` is set, every sample is divided by
    ``mean / rand_avg`` so that its mean moves towards a randomly drawn
    target (``args.sample_gain`` selects how the target is drawn), then
    clamped again. Noise parameters come from
    ``random_noise_levels_nightimaging`` and the noise itself from
    ``add_noise``; the noisy result is clamped to [-0.1, 1].

    Args:
        batches_aug: 5-D tensor (the mean is taken over axes 1..4).
        args: options object; reads ``need_Scaling``, ``sample_gain`` and
            ``local_rank`` (the CUDA device index) and is forwarded to the
            helpers.

    Returns:
        ``(clean, noisy)`` as float tensors.
    """
    frames = torch.clamp(batches_aug / 959, 0, 1)
    B = frames.shape[0]
    sample_mean = frames.mean(dim=(1, 2, 3, 4))

    if not args.need_Scaling:
        dark = frames
    else:
        if args.sample_gain == 'type1':
            # Target mean drawn uniformly from [0.001, 0.121).
            rand_avg = (torch.rand(B) * 0.12 + 0.001).cuda(args.local_rank)
        if args.sample_gain == 'type2':
            rand_avg = Gain_Sampler(B).cuda(args.local_rank)
        # One divisor per sample, broadcast over the four remaining axes.
        coef = (sample_mean / rand_avg)[:, None, None, None, None]
        dark = torch.clamp(frames / coef, 0, 1)

    a, b, again, dgain = random_noise_levels_nightimaging(B, args)
    device = args.local_rank
    dark, noisy = add_noise(args, dark, a.cuda(device), b.cuda(device), dgain.cuda(device))
    noisy = torch.clamp(noisy, -0.1, 1)
    return dark.float(), noisy.float()
def random_noise_levels_nightimaging(B, args):
    """Pick ``B`` random rows of the stored noise profile.

    A row index is drawn per sample by truncating a uniform draw from
    [0, 125) to an integer; column 0 of the selected rows is returned as
    ``a`` and column 1 as ``b``. The profile is loaded from a fixed ``.npy``
    path on every call.

    Args:
        B: number of samples.
        args: not used by this function.

    Returns:
        ``(a, b, 1, ones)`` where ``ones`` is a one-element tensor holding 1.
    """
    row_idx = torch.FloatTensor(B).uniform_(0, 125).int().long()
    table = np.load('/data1/chengqihua/02_code/03_night_photogrphy/nightimage_v1/dataloader/json_all_2nd.npy')
    noise_profile = torch.from_numpy(table)
    a, b = noise_profile[row_idx, 0], noise_profile[row_idx, 1]
    return a, b, 1, torch.ones(1)
def random_noise_levels(B, args):
    """Draw ``B`` random gains and derive noise parameters from them.

    A gain ``g`` is drawn uniformly from ``[args.min_gain, args.max_gain)``
    and split into ``again`` (``g`` capped at 16) and ``dgain`` (``g / 16``,
    but never below 1) - presumably the analog and digital parts. ``a`` is
    linear and ``b`` quadratic in ``again``.

    Args:
        B: number of samples.
        args: options object; reads ``min_gain`` and ``max_gain``.

    Returns:
        ``(a, b, again, dgain)``, four tensors of shape (B,).
    """
    ak1, ak2 = 0.05244803, 0.01498041
    bk1, bk2, bk3 = 0.00648923, 0.05899386, 0.21520193

    g = torch.FloatTensor(B).uniform_(args.min_gain, args.max_gain)
    again = torch.clamp(g, max=16)
    dgain = torch.clamp(g / 16, min=1)

    a = ak1 * again + ak2
    b = bk1 * again * again + bk2 * again + bk3
    return a, b, again, dgain
def add_noise(args, image, a, b, dgain):
    """Add Poisson plus Gaussian noise to a batch of frames.

    Args:
        args: options object; reads ``local_rank`` (CUDA device index) and
            ``usedgain``.
        image: 5-D tensor unpacked as (B, F, C, H, W).
        a: scale of the Poisson component (assumed 1-D, one value per
            sample - TODO confirm).
        b: variance of the Gaussian component (same assumption).
        dgain: gain the image is divided by before noise is added (same
            assumption; may also hold a single value).

    Returns:
        ``(image, noiseimg)``: the gain-divided image and its noisy version;
        both are multiplied back by ``dgain`` when ``args.usedgain`` is set.
    """
    # Append four singleton axes so the values broadcast over (F, C, H, W).
    dgain = dgain.unsqueeze(1).unsqueeze(1).unsqueeze(1).unsqueeze(1)
    a = a.unsqueeze(1).unsqueeze(1).unsqueeze(1).unsqueeze(1)
    b = b.unsqueeze(1).unsqueeze(1).unsqueeze(1).unsqueeze(1)
    B, F, C, H, W = image.size()
    image = image / dgain
    # Signal-dependent part: Poisson draw with rate image/a, scaled back by a.
    poisson_noisy_img = torch.poisson(image/a)*a
    # Signal-independent part: zero-mean Gaussian with standard deviation
    # sqrt(b); sampled with torch.randn and then moved to the CUDA device.
    gaussian_noise = torch.sqrt(b)*torch.randn(B, F, C, H, W).cuda(args.local_rank)
    noiseimg = poisson_noisy_img + gaussian_noise
    if args.usedgain :
        # Undo the division by dgain on both outputs.
        noiseimg = noiseimg * dgain
        image = image * dgain
    return image, noiseimg
def normalize_augment(datain):
    '''Scale a patch batch of dim [N, num_frames, C, H, W] from [0., 255.] to
    [0., 1.] and apply one randomly chosen augmentation to it.

    The augmentation is picked among: identity (weight 32) and, with weight 12
    each, vertical flip, rotation by 90/180/270 degrees with or without a
    vertical flip, and adding a per-sample constant offset drawn from a
    normal distribution with standard deviation 5/255.

    Returns the augmented batch twice (input and ground truth are the same
    tensor, all frames included), shaped [N, num_frames, C, H, W].
    '''
    def transform(sample):
        def _named(fn, label):
            fn.__name__ = label
            return fn

        # (weight, operation) pairs; identity gets one fourth of the total.
        table = [
            (32, _named(lambda x: x, 'do_nothing')),
            (12, _named(lambda x: torch.flip(x, dims=[2]), 'flipup')),
            (12, _named(lambda x: torch.rot90(x, k=1, dims=[2, 3]), 'rot90')),
            (12, _named(lambda x: torch.flip(torch.rot90(x, k=1, dims=[2, 3]), dims=[2]),
                        'rot90_flipud')),
            (12, _named(lambda x: torch.rot90(x, k=2, dims=[2, 3]), 'rot180')),
            (12, _named(lambda x: torch.flip(torch.rot90(x, k=2, dims=[2, 3]), dims=[2]),
                        'rot180_flipud')),
            (12, _named(lambda x: torch.rot90(x, k=3, dims=[2, 3]), 'rot270')),
            (12, _named(lambda x: torch.flip(torch.rot90(x, k=3, dims=[2, 3]), dims=[2]),
                        'rot270_flipud')),
            (12, _named(lambda x: x + torch.normal(mean=torch.zeros(x.size()[0], 1, 1, 1),
                                                   std=(5/255.)).expand_as(x).to(x.device),
                        'add_csnt')),
        ]
        weights = [weight for weight, _ in table]
        operations = [op for _, op in table]
        # Pick one operation and apply it to the whole batch.
        picked = choices(operations, weights)
        return picked[0](sample)

    N, F, C, H, W = datain.shape
    # Merge frames and channels: [N, num_frames*C, H, W], scaled to [0., 1.]
    merged = datain.view(datain.size()[0], -1, datain.size()[-2], datain.size()[-1]) / 255.
    augmented = transform(merged).view(N, F, C, H, W)
    return augmented, augmented
def Gain_Sampler(B):
    """Draw ``B`` integer-valued gains, favouring the low range.

    For every sample a level is picked with probability 0.7 / 0.2 / 0.1 for
    'low' / 'mid' / 'high'; the gain is then drawn uniformly from that
    level's half-open integer range ``[start, stop)``.

    Args:
        B: number of gains to draw.

    Returns:
        float32 tensor of shape (B,).
    """
    gain_dict = {
        'low': [5, 35],
        'mid': [35, 60],
        'high': [60, 100],
    }
    level = ['low', 'mid', 'high']
    # The probabilities must be passed as ``p=``: the third positional
    # parameter of np.random.choice is ``replace``, so passing them
    # positionally silently produced uniform sampling.
    sampled = np.random.choice(level, B, p=[0.7, 0.2, 0.1])
    gains = [
        torch.randint(gain_dict[name][0], gain_dict[name][1], (1,)).item()
        for name in sampled
    ]
    return torch.Tensor(gains)
def path_replace(path, args):
    """Apply the configured substring substitutions to ``path`` in order.

    Args:
        path: string to rewrite.
        args: options object; ``args.replace_left[i]`` is replaced by
            ``args.replace_right[i]`` for every index of ``replace_left``.

    Returns:
        The rewritten string.
    """
    for idx, old in enumerate(args.replace_left):
        path = path.replace(old, args.replace_right[idx])
    return path