from __future__ import division import numpy as np import cv2 import random import torch import glob import os from random import choices from scipy.stats import poisson def Rawread(path,low=0): if path.endswith('.raw'): return read_img(path,low) if path.endswith('.npy'): return read_npy(path,low) if path.endswith('.png'): return read_png(path,low) def read_img(path,low): w = 4000 h = 3000 raw = np.fromfile(path,np.uint16) raw = raw.reshape((h,w)) raw = raw.astype(np.float32)-64 raw = rggb_raw(raw) raw = np.clip(raw, low, 959) return raw def read_npy(path,low): raw = np.load(path) if raw.shape[0] == 4: return raw * 959 raw = raw.astype(np.float32)-64 raw = rggb_raw(raw) raw = np.clip(raw, low, 959) return raw def read_rawpng(path, metadata): raw = cv2.imread(str(path), cv2.IMREAD_UNCHANGED) # if raw.shape[0] == 4: # return raw * 959 raw = ((raw.astype(np.float32) - 256.) / (4095.- 256.)).clip(0, 1) raw = bayer2raw(raw, metadata) raw = np.clip(raw, 0., 1.) return raw def read_png(path, low): raw = cv2.imread(str(path), cv2.IMREAD_UNCHANGED) if raw.shape[0] == 4: return raw * 959 raw = raw.astype(np.float32)-256 raw = rggb_raw(raw) raw = np.clip(raw, low, 4095) return raw def random_crop(frames_0,frames_1=None ,crop_size=128): F,C, H, W = frames_0.shape rnd_w = random.randint(0, W - crop_size) rnd_h = random.randint(0, H - crop_size) patch = frames_0[..., rnd_h:rnd_h + crop_size, rnd_w:rnd_w + crop_size] if not frames_1 is None: path1 = frames_1[..., rnd_h:rnd_h + crop_size, rnd_w:rnd_w + crop_size] return np.concatenate([patch,path1],axis=0) return patch def rggb_raw(raw): # pack RGGB Bayer raw to 4 channels H, W = raw.shape raw = raw[None, ...] raw_pack = np.concatenate((raw[:, 0:H:2, 0:W:2], raw[:, 0:H:2, 1:W:2], raw[:, 1:H:2, 0:W:2], raw[:, 1:H:2, 1:W:2]), axis=0) return raw_pack def bayer2raw(raw, metadata): # pack RGGB Bayer raw to 4 channels H, W = raw.shape raw = raw[None, ...] if metadata['cfa_pattern'][0] == 0: # RGGB raw_pack = np.concatenate((raw[:, 0:H:2, 0:W:2], raw[:, 0:H:2, 1:W:2], raw[:, 1:H:2, 0:W:2], raw[:, 1:H:2, 1:W:2]), axis=0) else : # BGGR raw_pack = np.concatenate((raw[:, 1:H:2, 1:W:2], raw[:, 0:H:2, 1:W:2], raw[:, 1:H:2, 0:W:2], raw[:, 0:H:2, 0:W:2]), axis=0) return raw_pack def raw_rggb(raws): # depack 4 channels raw to RGGB Bayer C, H, W = raws.shape output = np.zeros((H * 2, W * 2)).astype(np.uint16) output[0:2 * H:2, 0:2 * W:2] = raws[0:1, :, :] output[0:2 * H:2, 1:2 * W:2] = raws[1:2, :, :] output[1:2 * H:2, 0:2 * W:2] = raws[2:3, :, :] output[1:2 * H:2, 1:2 * W:2] = raws[3:4, :, :] return output def raw_rggb_float32(raws): # depack 4 channels raw to RGGB Bayer C, H, W = raws.shape output = np.zeros((H * 2, W * 2)).astype(np.float32) output[0:2 * H:2, 0:2 * W:2] = raws[0:1, :, :] output[0:2 * H:2, 1:2 * W:2] = raws[1:2, :, :] output[1:2 * H:2, 0:2 * W:2] = raws[2:3, :, :] output[1:2 * H:2, 1:2 * W:2] = raws[3:4, :, :] return output def depack_rggb_raws(raws): # depack 4 channels raw to RGGB Bayer N, C, H, W = raws.shape output = torch.zeros((N, 1, H * 2, W * 2)) output[:, :, 0:2 * H:2, 0:2 * W:2] = raws[:, 0:1, :, :] output[:, :, 0:2 * H:2, 1:2 * W:2] = raws[:, 1:2, :, :] output[:, :, 1:2 * H:2, 0:2 * W:2] = raws[:, 2:3, :, :] output[:, :, 1:2 * H:2, 1:2 * W:2] = raws[:, 3:4, :, :] return output # IMAGETYPES = ('*.bmp', '*.png', '*.jpg', '*.jpeg', '*.tif') IMAGETYPES = ('*.npy','*.raw',) #得加逗号 不然会拆分字符串 def get_imagenames(seq_dir, pattern=None): """ Get ordered list of filenames """ files = [] for typ in IMAGETYPES: files.extend(glob.glob(os.path.join(seq_dir, typ))) # filter filenames if not pattern is None: ffiltered = [] ffiltered = [f for f in files if pattern in os.path.split(f)[-1]] files = ffiltered del ffiltered # sort filenames alphabetically files.sort(key=lambda f: int(''.join(filter(str.isdigit, f)))) return files def get_imagenames(seq_dir, pattern=None): """ Get ordered list of filenames """ files = [] for typ in IMAGETYPES: files.extend(glob.glob(os.path.join(seq_dir, typ))) # filter filenames if not pattern is None: ffiltered = [] ffiltered = [f for f in files if pattern in os.path.split(f)[-1]] files = ffiltered del ffiltered # sort filenames alphabetically files.sort(key=lambda f: int(''.join(filter(str.isdigit, f)))) return files def open_sequence(seq_dir, gray_mode, expand_if_needed=False, max_num_fr=100): r""" Opens a sequence of images and expands it to even sizes if necesary Args: fpath: string, path to image sequence gray_mode: boolean, True indicating if images is to be open are in grayscale mode expand_if_needed: if True, the spatial dimensions will be expanded if size is odd expand_axis0: if True, output will have a fourth dimension max_num_fr: maximum number of frames to load Returns: seq: array of dims [num_frames, C, H, W], C=1 grayscale or C=3 RGB, H and W are even. The image gets normalized gets normalized to the range [0, 1]. expanded_h: True if original dim H was odd and image got expanded in this dimension. expanded_w: True if original dim W was odd and image got expanded in this dimension. """ # Get ordered list of filenames files = get_imagenames(seq_dir) seq_list_raw = [] seq_list_raw_noise = [] print("\tOpen sequence in folder: ", seq_dir) for fpath in files[0:max_num_fr]: raw, raw_noise, expanded_h, expanded_w = open_image(fpath,\ gray_mode=gray_mode,\ expand_if_needed=expand_if_needed,\ expand_axis0=False) raw = rggb_raw(raw) raw_noise = rggb_raw(raw_noise) seq_list_raw.append(raw) seq_list_raw_noise.append(raw_noise) seq_raw = np.stack(seq_list_raw, axis=0) seq_raw_noise = np.stack(seq_list_raw_noise, axis=0) return seq_raw, seq_raw_noise, expanded_h, expanded_w def open_image(fpath, gray_mode, expand_if_needed=False, expand_axis0=True, normalize_data=True): r""" Opens an image and expands it if necesary Args: fpath: string, path of image file gray_mode: boolean, True indicating if image is to be open in grayscale mode expand_if_needed: if True, the spatial dimensions will be expanded if size is odd expand_axis0: if True, output will have a fourth dimension Returns: img: image of dims NxCxHxW, N=1, C=1 grayscale or C=3 RGB, H and W are even. if expand_axis0=False, the output will have a shape CxHxW. The image gets normalized to the range [0, 1]. expanded_h: True if original dim H was odd and image got expanded in this dimension. expanded_w: True if original dim W was odd and image got expanded in this dimension. """ # if not gray_mode: # # Open image as a CxHxW torch.Tensor # img = cv2.imread(fpath) # # from HxWxC to CxHxW, RGB image # img = (cv2.cvtColor(img, cv2.COLOR_BGR2RGB)).transpose(2, 0, 1) # else: # # from HxWxC to CxHxW grayscale image (C=1) # img = cv2.imread(fpath, cv2.IMREAD_GRAYSCALE) # 测试真实的图片 # raw_img = ((np.fromfile(fpath,np.uint16).astype(np.float32))*4833)/2048 # raw_img = np.clip(raw_img-64, 0, 1023-64) # raw_img = raw_img.reshape((3000,4000)) # raw_img = np.load(fpath).astype(np.float32)-64 w = 4000 h = 3000 raw_img = np.fromfile(fpath,dtype=np.uint16,count=w*h) raw_img = raw_img.reshape((h,w)).astype(np.float32)-64 raw_img = np.clip(raw_img, 0, 959) noise_fpath =fpath.replace('onlyraw_test_clean_raw','onlyraw_test_noise_raw') raw_img_noise = np.fromfile(noise_fpath,dtype=np.uint16,count=w*h) raw_img_noise = raw_img_noise.reshape((h,w)).astype(np.float32)-64 raw_img_noise = np.clip(raw_img_noise, 0, 959) #blc # if expand_axis0: # img = np.expand_dims(img, 0) # Handle odd sizes expanded_h = False expanded_w = False sh_im = raw_img.shape # if expand_if_needed: # if sh_im[-2]%2 == 1: # expanded_h = True # if expand_axis0: # img = np.concatenate((img, \ # img[:, :, -1, :][:, :, np.newaxis, :]), axis=2) # else: # img = np.concatenate((img, \ # img[:, -1, :][:, np.newaxis, :]), axis=1) # if sh_im[-1]%2 == 1: # expanded_w = True # if expand_axis0: # img = np.concatenate((img, \ # img[:, :, :, -1][:, :, :, np.newaxis]), axis=3) # else: # img = np.concatenate((img, \ # img[:, :, -1][:, :, np.newaxis]), axis=2) if normalize_data: raw_img = normalize(raw_img) raw_img_noise = normalize(raw_img_noise) return raw_img, raw_img_noise, expanded_h, expanded_w def normalize(data): r"""Normalizes a unit8 image to a float32 image in the range [0, 1] Args: data: a unint8 numpy array to normalize from [0, 255] to [0, 1] """ return np.float32(data/(959)) def augment_cuda(batches, args, spynet=None): def _augment(img, hflip=True, rot=True): hflip = hflip and random.random() < 0.5 vflip = rot and random.random() < 0.5 # rot90 = rot and random.random() < 0.5 k1 = np.random.randint(0, 4) #0,1,2,3 if hflip: img = img.flip(-1) if vflip: img = img.flip(-2) img = torch.rot90(img, k=k1, dims=[-2, -1]) return img batches_aug = _augment(batches) if args.pair: noise = batches_aug[:,:args.frame,...]/959 clean = batches_aug[:,args.frame,...]/959 #if args.scene != 'noisedata' else batches_aug[:,args.frame,...] else: clean, noise = Noise_simulation(batches_aug,args) if not args.consistent_loss: clean = clean[:, args.frame // 2, ...] B, F, C , H, W = noise.shape noise = noise.reshape(B, F*C , H, W ) return clean, noise, None def Noise_simulation(batches_aug,args): batches_aug = batches_aug/959 batches_aug = torch.clamp(batches_aug , 0, 1) B = batches_aug.shape[0] batch_aug_mean = batches_aug.mean(dim=(1,2,3,4)) if args.need_Scaling: if args.sample_gain == 'type1': # rand_avg = torch.randint(args.luminance_low, args.luminance_high ,(B, )).cuda(args.local_rank) rand_avg = (torch.rand((B)) * 0.12 + 0.001).cuda(args.local_rank) if args.sample_gain == 'type2': rand_avg = Gain_Sampler(B).cuda(args.local_rank) coef = (batch_aug_mean / rand_avg).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) batch_aug_dark = torch.clamp(batches_aug / coef, 0, 1) else: batch_aug_dark = batches_aug a,b, again, dgain = random_noise_levels_nightimaging(B, args) batch_aug_dark,batch_aug_dark_noise = add_noise(args, batch_aug_dark,a.cuda(args.local_rank),b.cuda(args.local_rank),dgain.cuda(args.local_rank)) batch_aug_dark_noise = torch.clamp(batch_aug_dark_noise, -0.1, 1) # print(batch_aug_dark_noise.mean()) return batch_aug_dark.float(), batch_aug_dark_noise.float() def random_noise_levels_nightimaging(B, args): # print('use new') g = torch.FloatTensor(B).uniform_(0, 125).int().long() noise_profile = torch.from_numpy(np.load('/data1/chengqihua/02_code/03_night_photogrphy/nightimage_v1/dataloader/json_all_2nd.npy')) a = noise_profile[g,0] b = noise_profile[g,1] return a, b, 1, 1*torch.ones(1) def random_noise_levels(B, args): ak1=0.05244803 ak2=0.01498041 bk1=0.00648923 bk2= 0.05899386 bk3 = 0.21520193 g = torch.FloatTensor(B).uniform_(args.min_gain, args.max_gain) maskA = g > 16 again = g.clone() again[maskA] = 16 maskB = g < 16 dgain = g.clone() / 16 dgain[maskB] = 1 a = ak1 * again + ak2 b = bk1 * again*again + bk2* again + bk3 return a, b, again, dgain def add_noise(args, image, a, b, dgain): dgain = dgain.unsqueeze(1).unsqueeze(1).unsqueeze(1).unsqueeze(1) a = a.unsqueeze(1).unsqueeze(1).unsqueeze(1).unsqueeze(1) b = b.unsqueeze(1).unsqueeze(1).unsqueeze(1).unsqueeze(1) B, F, C, H, W = image.size() image = image / dgain poisson_noisy_img = torch.poisson(image/a)*a gaussian_noise = torch.sqrt(b)*torch.randn(B, F, C, H, W).cuda(args.local_rank) noiseimg = poisson_noisy_img + gaussian_noise if args.usedgain : noiseimg = noiseimg * dgain image = image * dgain return image, noiseimg def normalize_augment(datain): '''Normalizes and augments an input patch of dim [N, num_frames, C. H, W] in [0., 255.] to \ [N, num_frames*C. H, W] in [0., 1.]. It also returns the central (edited by cjm : now all frames) frame of the temporal \ patch as a ground truth. ''' def transform(sample): # define transformations do_nothing = lambda x: x do_nothing.__name__ = 'do_nothing' flipud = lambda x: torch.flip(x, dims=[2]) flipud.__name__ = 'flipup' rot90 = lambda x: torch.rot90(x, k=1, dims=[2, 3]) rot90.__name__ = 'rot90' rot90_flipud = lambda x: torch.flip(torch.rot90(x, k=1, dims=[2, 3]), dims=[2]) rot90_flipud.__name__ = 'rot90_flipud' rot180 = lambda x: torch.rot90(x, k=2, dims=[2, 3]) rot180.__name__ = 'rot180' rot180_flipud = lambda x: torch.flip(torch.rot90(x, k=2, dims=[2, 3]), dims=[2]) rot180_flipud.__name__ = 'rot180_flipud' rot270 = lambda x: torch.rot90(x, k=3, dims=[2, 3]) rot270.__name__ = 'rot270' rot270_flipud = lambda x: torch.flip(torch.rot90(x, k=3, dims=[2, 3]), dims=[2]) rot270_flipud.__name__ = 'rot270_flipud' add_csnt = lambda x: x + torch.normal(mean=torch.zeros(x.size()[0], 1, 1, 1), \ std=(5/255.)).expand_as(x).to(x.device) add_csnt.__name__ = 'add_csnt' # define transformations and their frequency, then pick one. aug_list = [do_nothing, flipud, rot90, rot90_flipud, \ rot180, rot180_flipud, rot270, rot270_flipud, add_csnt] w_aug = [32, 12, 12, 12, 12, 12, 12, 12, 12] # one fourth chances to do_nothing transf = choices(aug_list, w_aug) # transform all images in array return transf[0](sample) img_train = datain #torch.Size([8, 11, 3, 96, 96]) # convert to [N, num_frames*C. H, W] in [0., 1.] from [N, num_frames, C. H, W] in [0., 255.] N, F, C, H, W = img_train.shape img_train = img_train.view(img_train.size()[0], -1, \ img_train.size()[-2], img_train.size()[-1]) / 255. # torch.Size([8, 33, 96, 96]) #augment img_train = transform(img_train) img_train = img_train.view(N, F, C, H, W) # extract ground truth (central frame) # gt_train = img_train[:, 3*ctrl_fr_idx:3*ctrl_fr_idx+3, :, :] return img_train, img_train def Gain_Sampler(B): gain_dict = { 'low':[5,35], 'mid':[35,60], 'high':[60,100] } level = ['low','mid','high'] sampled = np.random.choice(level,B,[0.7,0.2,0.1]) all = [] for index in sampled: all.append(torch.randint(gain_dict[index][0],gain_dict[index][1],(1,))) return torch.Tensor(all) def path_replace(path,args): for i in range(len(args.replace_left)): path = path.replace(args.replace_left[i],args.replace_right[i]) return path