from __future__ import division

import numpy as np
import cv2
import random
import torch
import glob
import os
from random import choices

def Rawread(path, low=0):
    # Dispatch on file extension; raises for unsupported formats.
    if path.endswith('.raw'):
        return read_img(path, low)
    if path.endswith('.npy'):
        return read_npy(path, low)
    if path.endswith('.png'):
        return read_png(path, low)
    raise ValueError('Unsupported raw file type: {}'.format(path))
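
# Example (illustrative; assumes a 4000x3000 10-bit RGGB sensor dump):
#   raw = Rawread('frame_0001.raw')  # -> float32 array of shape (4, 1500, 2000)
#   raw.max() <= 959                 # black level 64 already subtracted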
          
def read_img(path, low):
    # Fixed sensor resolution for the .raw dumps used in this project.
    w = 4000
    h = 3000

    raw = np.fromfile(path, np.uint16)
    raw = raw.reshape((h, w))
    # Subtract the black level (64) and clip to the 10-bit range: 1023 - 64 = 959.
    raw = raw.astype(np.float32) - 64
    raw = rggb_raw(raw)
    raw = np.clip(raw, low, 959)

    return raw


def read_npy(path, low):
    raw = np.load(path)

    # Already packed to 4 channels: assume values in [0, 1] and rescale to [0, 959].
    if raw.shape[0] == 4:
        return raw * 959
    raw = raw.astype(np.float32) - 64
    raw = rggb_raw(raw)
    raw = np.clip(raw, low, 959)
    return raw

def read_rawpng(path, metadata):
    raw = cv2.imread(str(path), cv2.IMREAD_UNCHANGED)

    # Normalize from the 12-bit range with black level 256 to [0, 1].
    raw = ((raw.astype(np.float32) - 256.) / (4095. - 256.)).clip(0, 1)

    raw = bayer2raw(raw, metadata)
    raw = np.clip(raw, 0., 1.)
    return raw

def read_png(path, low):
    raw = cv2.imread(str(path), cv2.IMREAD_UNCHANGED)

    # Already packed to 4 channels: assume values in [0, 1] and rescale to [0, 959].
    if raw.shape[0] == 4:
        return raw * 959
    # Subtract the black level (256) and clip to the 12-bit range.
    raw = raw.astype(np.float32) - 256
    raw = rggb_raw(raw)
    raw = np.clip(raw, low, 4095)
    return raw

def random_crop(frames_0, frames_1=None, crop_size=128):
    # Crop the same random window from one (or a pair of) frame stacks.
    F, C, H, W = frames_0.shape

    rnd_w = random.randint(0, W - crop_size)
    rnd_h = random.randint(0, H - crop_size)

    patch = frames_0[..., rnd_h:rnd_h + crop_size, rnd_w:rnd_w + crop_size]
    if frames_1 is not None:
        patch1 = frames_1[..., rnd_h:rnd_h + crop_size, rnd_w:rnd_w + crop_size]
        return np.concatenate([patch, patch1], axis=0)

    return patch

def rggb_raw(raw):
    # pack RGGB Bayer raw to 4 channels
    H, W = raw.shape
    raw = raw[None, ...]
    raw_pack = np.concatenate((raw[:, 0:H:2, 0:W:2],
                               raw[:, 0:H:2, 1:W:2],
                               raw[:, 1:H:2, 0:W:2],
                               raw[:, 1:H:2, 1:W:2]), axis=0)
    return raw_pack

def bayer2raw(raw, metadata):
    # pack Bayer raw to 4 channels in RGGB order, honoring the CFA pattern
    H, W = raw.shape
    raw = raw[None, ...]
    if metadata['cfa_pattern'][0] == 0:
        # RGGB
        raw_pack = np.concatenate((raw[:, 0:H:2, 0:W:2],
                                   raw[:, 0:H:2, 1:W:2],
                                   raw[:, 1:H:2, 0:W:2],
                                   raw[:, 1:H:2, 1:W:2]), axis=0)
    else:
        # BGGR: swap the corners so the output channel order is still R, G, G, B
        raw_pack = np.concatenate((raw[:, 1:H:2, 1:W:2],
                                   raw[:, 0:H:2, 1:W:2],
                                   raw[:, 1:H:2, 0:W:2],
                                   raw[:, 0:H:2, 0:W:2]), axis=0)
    return raw_pack

def raw_rggb(raws):
    # depack 4 channels raw to RGGB Bayer
    C, H, W = raws.shape
    output = np.zeros((H * 2, W * 2)).astype(np.uint16)

    output[0:2 * H:2, 0:2 * W:2] = raws[0:1, :, :]
    output[0:2 * H:2, 1:2 * W:2] = raws[1:2, :, :]
    output[1:2 * H:2, 0:2 * W:2] = raws[2:3, :, :]
    output[1:2 * H:2, 1:2 * W:2] = raws[3:4, :, :]

    return output


def raw_rggb_float32(raws):
    # depack 4 channels raw to RGGB Bayer
    C, H, W = raws.shape
    output = np.zeros((H * 2, W * 2)).astype(np.float32)

    output[0:2 * H:2, 0:2 * W:2] = raws[0:1, :, :]
    output[0:2 * H:2, 1:2 * W:2] = raws[1:2, :, :]
    output[1:2 * H:2, 0:2 * W:2] = raws[2:3, :, :]
    output[1:2 * H:2, 1:2 * W:2] = raws[3:4, :, :]

    return output


def depack_rggb_raws(raws):
    # depack 4 channels raw to RGGB Bayer (torch, batched)
    N, C, H, W = raws.shape
    # match dtype/device of the input to avoid CPU/GPU mismatches
    output = torch.zeros((N, 1, H * 2, W * 2), dtype=raws.dtype, device=raws.device)

    output[:, :, 0:2 * H:2, 0:2 * W:2] = raws[:, 0:1, :, :]
    output[:, :, 0:2 * H:2, 1:2 * W:2] = raws[:, 1:2, :, :]
    output[:, :, 1:2 * H:2, 0:2 * W:2] = raws[:, 2:3, :, :]
    output[:, :, 1:2 * H:2, 1:2 * W:2] = raws[:, 3:4, :, :]

    return output
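
# A minimal round-trip sanity check for the pack/depack helpers above
# (illustrative only, on a small random mosaic):
#   bayer = np.random.randint(0, 960, (8, 8)).astype(np.float32)
#   packed = rggb_raw(bayer)                        # shape (4, 4, 4)
#   assert np.array_equal(raw_rggb_float32(packed), bayer)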



# IMAGETYPES = ('*.bmp', '*.png', '*.jpg', '*.jpeg', '*.tif')
IMAGETYPES = ('*.npy', '*.raw',)  # keep the trailing comma: a one-element tuple without it is a plain string and would be iterated character by character


def get_imagenames(seq_dir, pattern=None):
    """ Get ordered list of filenames
    """
    files = []
    for typ in IMAGETYPES:
        files.extend(glob.glob(os.path.join(seq_dir, typ)))

    # filter filenames
    if pattern is not None:
        files = [f for f in files if pattern in os.path.split(f)[-1]]

    # sort filenames numerically by the digits embedded in each name
    files.sort(key=lambda f: int(''.join(filter(str.isdigit, f))))
    return files

def open_sequence(seq_dir, gray_mode, expand_if_needed=False, max_num_fr=100):
    r""" Opens a sequence of clean/noisy raw image pairs and expands them to even sizes if necessary
    Args:
        seq_dir: string, path to image sequence
        gray_mode: boolean, True if images are to be opened in grayscale mode
        expand_if_needed: if True, the spatial dimensions will be expanded if
            a size is odd
        max_num_fr: maximum number of frames to load
    Returns:
        seq_raw: clean frames, array of dims [num_frames, 4, H/2, W/2], packed RGGB,
            normalized to the range [0, 1].
        seq_raw_noise: noisy frames, same shape and range as seq_raw.
        expanded_h: True if original dim H was odd and the image was expanded.
        expanded_w: True if original dim W was odd and the image was expanded.
    """
    # Get ordered list of filenames
    files = get_imagenames(seq_dir)

    seq_list_raw = []
    seq_list_raw_noise = []
    print("\tOpen sequence in folder: ", seq_dir)
    for fpath in files[0:max_num_fr]:

        raw, raw_noise,  expanded_h, expanded_w = open_image(fpath,\
                                                   gray_mode=gray_mode,\
                                                   expand_if_needed=expand_if_needed,\
                                                   expand_axis0=False)
        
        raw = rggb_raw(raw)
        raw_noise = rggb_raw(raw_noise)


        seq_list_raw.append(raw)
        seq_list_raw_noise.append(raw_noise)
    seq_raw = np.stack(seq_list_raw, axis=0)
    seq_raw_noise = np.stack(seq_list_raw_noise, axis=0)
    return seq_raw, seq_raw_noise,  expanded_h, expanded_w

def open_image(fpath, gray_mode, expand_if_needed=False, expand_axis0=True, normalize_data=True):
    r""" Opens a clean/noisy raw image pair and expands it if necessary
    Args:
        fpath: string, path of the clean raw file
        gray_mode: boolean, True if the image is to be opened in grayscale mode (unused here)
        expand_if_needed: if True, the spatial dimensions will be expanded if a size is odd
        expand_axis0: if True, output will have a fourth dimension
        normalize_data: if True, scale values to the range [0, 1]
    Returns:
        raw_img: clean Bayer image of dims HxW, normalized to [0, 1] if normalize_data.
        raw_img_noise: matching noisy Bayer image, same dims and range.
        expanded_h: True if original dim H was odd and the image was expanded.
        expanded_w: True if original dim W was odd and the image was expanded.
    """
    # Load the clean raw frame (fixed 4000x3000, uint16) and subtract the black level (64).
    w = 4000
    h = 3000
    raw_img = np.fromfile(fpath, dtype=np.uint16, count=w * h)
    raw_img = raw_img.reshape((h, w)).astype(np.float32) - 64
    raw_img = np.clip(raw_img, 0, 959)

    # Load the matching noisy frame from the sibling noise directory.
    noise_fpath = fpath.replace('onlyraw_test_clean_raw', 'onlyraw_test_noise_raw')
    raw_img_noise = np.fromfile(noise_fpath, dtype=np.uint16, count=w * h)
    raw_img_noise = raw_img_noise.reshape((h, w)).astype(np.float32) - 64
    raw_img_noise = np.clip(raw_img_noise, 0, 959)


    # Handle odd sizes (expansion is currently disabled; the flags are kept for the API)
    expanded_h = False
    expanded_w = False
    if normalize_data:
        raw_img = normalize(raw_img)
        raw_img_noise = normalize(raw_img_noise)
    return raw_img, raw_img_noise,  expanded_h, expanded_w


def normalize(data):
    r"""Normalizes a raw image to a float32 image in the range [0, 1]

    Args:
        data: a numpy array in [0, 959] (10-bit range after black level subtraction)
    """
    return np.float32(data / 959)


def augment_cuda(batches, args, spynet=None):

    def _augment(img, hflip=True, rot=True):
        # Random horizontal/vertical flips plus a random multiple of 90-degree rotation.
        hflip = hflip and random.random() < 0.5
        vflip = rot and random.random() < 0.5
        k1 = np.random.randint(0, 4)  # 0, 1, 2 or 3 quarter turns
        if hflip: img = img.flip(-1)
        if vflip: img = img.flip(-2)

        img = torch.rot90(img, k=k1, dims=[-2, -1])

        return img

    batches_aug = _augment(batches)

    if args.pair:
        # Paired data: the first args.frame frames are noisy, the last one is clean.
        noise = batches_aug[:, :args.frame, ...] / 959
        clean = batches_aug[:, args.frame, ...] / 959
    else:
        clean, noise = Noise_simulation(batches_aug, args)
        if not args.consistent_loss:
            clean = clean[:, args.frame // 2, ...]
    B, F, C, H, W = noise.shape
    noise = noise.reshape(B, F * C, H, W)

    return clean, noise, None


def Noise_simulation(batches_aug, args):
    # Normalize to [0, 1], optionally darken to a random target mean, then add synthetic noise.
    batches_aug = batches_aug / 959
    batches_aug = torch.clamp(batches_aug, 0, 1)
    B = batches_aug.shape[0]
    batch_aug_mean = batches_aug.mean(dim=(1, 2, 3, 4))
    if args.need_Scaling:
        if args.sample_gain == 'type1':
            rand_avg = (torch.rand(B) * 0.12 + 0.001).cuda(args.local_rank)
        if args.sample_gain == 'type2':
            rand_avg = Gain_Sampler(B).cuda(args.local_rank)

        # Scale each sequence so its mean matches the sampled target brightness.
        coef = (batch_aug_mean / rand_avg).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        batch_aug_dark = torch.clamp(batches_aug / coef, 0, 1)
    else:
        batch_aug_dark = batches_aug

    a, b, again, dgain = random_noise_levels_nightimaging(B, args)
    batch_aug_dark, batch_aug_dark_noise = add_noise(args, batch_aug_dark,
                                                     a.cuda(args.local_rank),
                                                     b.cuda(args.local_rank),
                                                     dgain.cuda(args.local_rank))

    batch_aug_dark_noise = torch.clamp(batch_aug_dark_noise, -0.1, 1)

    return batch_aug_dark.float(), batch_aug_dark_noise.float()

def random_noise_levels_nightimaging(B, args):
    # Sample per-batch (a, b) noise parameters from a pre-measured noise profile.
    g = torch.FloatTensor(B).uniform_(0, 125).int().long()
    noise_profile = torch.from_numpy(np.load('/data1/chengqihua/02_code/03_night_photogrphy/nightimage_v1/dataloader/json_all_2nd.npy'))

    a = noise_profile[g, 0]
    b = noise_profile[g, 1]

    return a, b, 1, torch.ones(1)

def random_noise_levels(B, args):
    # Gain-dependent Poisson-Gaussian parameters fitted for the sensor:
    # a (shot noise) is linear in analog gain, b (read noise) is quadratic.
    ak1 = 0.05244803
    ak2 = 0.01498041
    bk1 = 0.00648923
    bk2 = 0.05899386
    bk3 = 0.21520193
    g = torch.FloatTensor(B).uniform_(args.min_gain, args.max_gain)

    # Analog gain saturates at 16x; anything above is applied as digital gain.
    maskA = g > 16
    again = g.clone()
    again[maskA] = 16

    maskB = g < 16
    dgain = g.clone() / 16
    dgain[maskB] = 1

    a = ak1 * again + ak2
    b = bk1 * again * again + bk2 * again + bk3

    return a, b, again, dgain
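
# Worked example (illustrative): for a sampled total gain g = 32, the analog
# gain saturates at again = 16 and dgain = 32 / 16 = 2, giving
#   a = 0.05244803 * 16 + 0.01498041 ≈ 0.854
#   b = 0.00648923 * 16**2 + 0.05899386 * 16 + 0.21520193 ≈ 2.820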

def add_noise(args, image, a, b, dgain):
    # Heteroscedastic Poisson-Gaussian model: shot noise scaled by a, read noise with variance b.
    dgain = dgain.unsqueeze(1).unsqueeze(1).unsqueeze(1).unsqueeze(1)
    a = a.unsqueeze(1).unsqueeze(1).unsqueeze(1).unsqueeze(1)
    b = b.unsqueeze(1).unsqueeze(1).unsqueeze(1).unsqueeze(1)

    B, F, C, H, W = image.size()

    # Remove digital gain before sampling so intensities match the analog signal.
    image = image / dgain

    # Shot noise: convert to expected photon counts, sample, and convert back.
    poisson_noisy_img = torch.poisson(image / a) * a

    # Read noise: additive Gaussian with variance b.
    gaussian_noise = torch.sqrt(b) * torch.randn(B, F, C, H, W).cuda(args.local_rank)

    noiseimg = poisson_noisy_img + gaussian_noise

    if args.usedgain:
        noiseimg = noiseimg * dgain
        image = image * dgain
    return image, noiseimg
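
# A minimal CPU sketch of the same Poisson-Gaussian model used by add_noise
# above. Illustration only: assumes scalar a and b (the training path uses
# per-batch CUDA tensors) and a hypothetical helper name.
def _simulate_poisson_gaussian(image, a=0.05, b=0.01):
    shot = torch.poisson(image / a) * a          # signal-dependent shot noise
    read = (b ** 0.5) * torch.randn_like(image)  # constant-variance read noise
    return shot + read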



def normalize_augment(datain):
    '''Normalizes and augments an input patch of dim [N, num_frames, C, H, W] in [0., 255.]
        to [N, num_frames*C, H, W] in [0., 1.]. It also returns all frames of the temporal
        patch as ground truth (edited by cjm: originally only the central frame).
    '''
    def transform(sample):
        # define transformations
        do_nothing = lambda x: x
        do_nothing.__name__ = 'do_nothing'
        flipud = lambda x: torch.flip(x, dims=[2])
        flipud.__name__ = 'flipud'
        rot90 = lambda x: torch.rot90(x, k=1, dims=[2, 3])
        rot90.__name__ = 'rot90'
        rot90_flipud = lambda x: torch.flip(torch.rot90(x, k=1, dims=[2, 3]), dims=[2])
        rot90_flipud.__name__ = 'rot90_flipud'
        rot180 = lambda x: torch.rot90(x, k=2, dims=[2, 3])
        rot180.__name__ = 'rot180'
        rot180_flipud = lambda x: torch.flip(torch.rot90(x, k=2, dims=[2, 3]), dims=[2])
        rot180_flipud.__name__ = 'rot180_flipud'
        rot270 = lambda x: torch.rot90(x, k=3, dims=[2, 3])
        rot270.__name__ = 'rot270'
        rot270_flipud = lambda x: torch.flip(torch.rot90(x, k=3, dims=[2, 3]), dims=[2])
        rot270_flipud.__name__ = 'rot270_flipud'
        add_csnt = lambda x: x + torch.normal(mean=torch.zeros(x.size()[0], 1, 1, 1), \
                                 std=(5/255.)).expand_as(x).to(x.device)
        add_csnt.__name__ = 'add_csnt'

        # define transformations and their frequency, then pick one.
        aug_list = [do_nothing, flipud, rot90, rot90_flipud, \
                    rot180, rot180_flipud, rot270, rot270_flipud, add_csnt]
        w_aug = [32, 12, 12, 12, 12, 12, 12, 12, 12]  # 32/128 = 1/4 chance of do_nothing
        transf = choices(aug_list, w_aug)

        # transform all images in array
        return transf[0](sample)

    img_train = datain  # e.g. torch.Size([8, 11, 3, 96, 96])
    # convert from [N, num_frames, C, H, W] in [0., 255.] to [N, num_frames*C, H, W] in [0., 1.]
    N, F, C, H, W = img_train.shape
    img_train = img_train.view(img_train.size()[0], -1,
                               img_train.size()[-2], img_train.size()[-1]) / 255.  # e.g. torch.Size([8, 33, 96, 96])

    # augment
    img_train = transform(img_train)
    img_train = img_train.view(N, F, C, H, W)
    # ground truth: all frames (originally only the central frame was extracted)
    return img_train, img_train
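
# Example (illustrative): augment a batch of 8 sequences of 11 RGB 96x96 frames.
#   batch = torch.randint(0, 256, (8, 11, 3, 96, 96)).float()
#   imgs, gt = normalize_augment(batch)  # both (8, 11, 3, 96, 96), scaled to [0, 1]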

def Gain_Sampler(B):
    # Sample a per-sequence gain, biased towards the low range.
    gain_dict = {
        'low': [5, 35],
        'mid': [35, 60],
        'high': [60, 100]
    }

    level = ['low', 'mid', 'high']
    # the probabilities must be passed via the keyword p; a third positional
    # argument to np.random.choice would be interpreted as `replace`
    sampled = np.random.choice(level, B, p=[0.7, 0.2, 0.1])
    gains = []
    for index in sampled:
        gains.append(torch.randint(gain_dict[index][0], gain_dict[index][1], (1,)).item())

    return torch.Tensor(gains)

def path_replace(path, args):
    # Apply each (left -> right) substring substitution from the CLI arguments.
    for left, right in zip(args.replace_left, args.replace_right):
        path = path.replace(left, right)
    return path
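
# Example (illustrative, with hypothetical argument values):
#   args.replace_left  = ['clean_raw']
#   args.replace_right = ['noise_raw']
#   path_replace('/data/clean_raw/0001.raw', args)  # -> '/data/noise_raw/0001.raw'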