GSASR / utils /split_and_joint_image.py
mt-cly
init
909940e
import torch
import torch.nn.functional as F
import math
from utils.gaussian_splatting import generate_2D_gaussian_splatting_step, generate_2D_gaussian_splatting_step_buffer
### If the GPU memory is limited, please use the following code to do tiling process for input LR image
# def split_and_joint_image(lq, scale_factor, model_g, model_fea2gs, scale_modify, split_size = 48,
# overlap_size = 8,
# crop_size = 4,
# default_step_size = 1.2, mode = 'scale_modify',
# cuda_rendering = True,
# if_dmax = False,
# dmax_mode = 'fix',
# dmax = 0.1):
# h_lq, w_lq = lq.shape[-2:]
# assert overlap_size > 0 and overlap_size < split_size // 2, f"overlap size is wrong"
# tile_nums_h = math.ceil((h_lq - overlap_size) / (split_size - overlap_size))
# tile_nums_w = math.ceil((w_lq - overlap_size) / (split_size - overlap_size))
# pad_h_lq = tile_nums_h * (split_size - overlap_size) + overlap_size - h_lq
# pad_w_lq = tile_nums_w * (split_size - overlap_size) + overlap_size - w_lq
# lq_pad = F.pad(input=lq, pad=(0, pad_w_lq, 0, pad_h_lq), mode='reflect')
# split_size_sr = math.ceil(split_size * scale_factor)
# sr_tile_list = []
# for h_num in range(tile_nums_h):
# for w_num in range(tile_nums_w):
# tile_lq_position_start_h = h_num * (split_size - overlap_size)
# tile_lq_position_start_w = w_num * (split_size - overlap_size)
# tile_lq_position_end_h = tile_lq_position_start_h + split_size
# tile_lq_position_end_w = tile_lq_position_start_w + split_size
# input_tile = lq_pad[:,:, tile_lq_position_start_h:tile_lq_position_end_h, tile_lq_position_start_w:tile_lq_position_end_w]
# model_g_output = model_g(input_tile)
# scale_vector = scale_modify[0].unsqueeze(0).to(model_g_output.device)
# batch_gs_parameters = model_fea2gs(model_g_output, scale_vector)
# gs_parameters = batch_gs_parameters[0, :]
# b_output = generate_2D_gaussian_splatting_step(sr_size=torch.tensor([split_size_sr, split_size_sr]), gs_parameters=gs_parameters,
# lq=input_tile[0, :], scale=scale_factor, sample_coords=None,
# scale_modify = scale_modify,
# default_step_size = default_step_size, mode = mode,
# cuda_rendering = cuda_rendering,
# if_dmax = if_dmax,
# dmax_mode = dmax_mode,
# dmax = dmax)
# sr_tile_list.append(b_output.unsqueeze(0))
# tile_sr_h = sr_tile_list[0].shape[2]
# tile_sr_w = sr_tile_list[0].shape[3]
# assert tile_sr_w == split_size_sr and tile_sr_h == split_size_sr, \
# f'tile_sr_h-{tile_sr_w}, tile_sr_w-{tile_sr_w}, split_size_sr-{split_size_sr} is not the same'
# overlap_sr = math.ceil(overlap_size * scale_factor)
# sr_pad = torch.zeros(lq.shape[0], lq.shape[1],
# math.ceil(lq_pad.shape[2] * scale_factor),
# math.ceil(lq_pad.shape[3] * scale_factor),
# device=lq.device)
# idx = 0
# for h_num in range(tile_nums_h):
# for w_num in range(tile_nums_w):
# tile_sr_position_start_w = w_num * (split_size_sr - overlap_sr)
# tile_sr_position_end_w = tile_sr_position_start_w + split_size_sr
# tile_sr_position_start_h = h_num * (split_size_sr - overlap_sr)
# tile_sr_position_end_h = tile_sr_position_start_h + split_size_sr
# if h_num == 0 and w_num == 0:
# sr_pad[:, :, tile_sr_position_start_h:tile_sr_position_end_h,
# tile_sr_position_start_w:tile_sr_position_end_w] = sr_tile_list[idx]
# elif h_num == 0 and w_num !=0:
# sr_pad[:, :, tile_sr_position_start_h:tile_sr_position_end_h,
# tile_sr_position_start_w+crop_size:tile_sr_position_end_w] = sr_tile_list[idx][:,:,:,crop_size:]
# elif h_num != 0 and w_num ==0:
# sr_pad[:, :, tile_sr_position_start_h+crop_size:tile_sr_position_end_h,
# tile_sr_position_start_w:tile_sr_position_end_w] = sr_tile_list[idx][:,:,crop_size:,:]
# else:
# sr_pad[:,:,tile_sr_position_start_h+crop_size:tile_sr_position_end_h,
# tile_sr_position_start_w+crop_size:tile_sr_position_end_w] = sr_tile_list[idx][:,:,crop_size:,crop_size:]
# idx = idx + 1
# print(f"sr_pad shape is {sr_pad.shape}")
# # sr_final = sr_pad[:,:, 0:math.ceil(h_lq * scale_factor), 0: math.ceil(w_lq * scale_factor)]
# sr_final = sr_pad
# return sr_final
def split_and_joint_image(lq, scale_factor, split_size,
overlap_size, model_g, model_fea2gs,
scale_modify, crop_size = 2,
default_step_size = 1.2, mode = 'scale_modify',
cuda_rendering = True,
if_dmax = False,
dmax_mode = 'fix',
dmax = 25):
h_lq, w_lq = lq.shape[-2:]
# assert h_lq > split_size, f'h_lq-{h_lq} should be larger than split_size-{split_size}, please do not use tile_process, or decrease the split_size'
# assert w_lq > split_size, f'w_lq-{w_lq} should be larger than split_size-{split_size}, please do not use tile_process, or decrease the split_size'
assert overlap_size > 0 and overlap_size < split_size // 2, f"overlap size is wrong"
tile_nums_h = math.ceil((h_lq - overlap_size) / (split_size - overlap_size))
tile_nums_w = math.ceil((w_lq - overlap_size) / (split_size - overlap_size))
pad_h_lq = tile_nums_h * (split_size - overlap_size) + overlap_size - h_lq
pad_w_lq = tile_nums_w * (split_size - overlap_size) + overlap_size - w_lq
assert pad_h_lq < h_lq, f'pad_h_lq-{pad_h_lq} should be smaller than h_lq-{h_lq}, please decrease the split_size-{split_size}'
assert pad_w_lq < w_lq, f'pad_w_lq-{pad_w_lq} should be smaller than w_lq-{w_lq}, please decrease the split_size-{split_size}'
lq_pad = F.pad(input=lq, pad=(0, pad_w_lq, 0, pad_h_lq), mode='reflect')
# lq_pad = F.pad(input=lq, pad=(0, pad_w_lq, 0, pad_h_lq), mode='constant', value=0)
split_size_sr = math.ceil(split_size * scale_factor)
sr_tile_list = []
for h_num in range(tile_nums_h):
for w_num in range(tile_nums_w):
tile_lq_position_start_h = h_num * (split_size - overlap_size)
tile_lq_position_start_w = w_num * (split_size - overlap_size)
tile_lq_position_end_h = tile_lq_position_start_h + split_size
tile_lq_position_end_w = tile_lq_position_start_w + split_size
input_tile = lq_pad[:,:, tile_lq_position_start_h:tile_lq_position_end_h, tile_lq_position_start_w:tile_lq_position_end_w]
model_g_output = model_g(input_tile)
scale_vector = scale_modify[0].unsqueeze(0).to(model_g_output.device)
batch_gs_parameters = model_fea2gs(model_g_output, scale_vector)
gs_parameters = batch_gs_parameters[0, :]
b_output = generate_2D_gaussian_splatting_step(sr_size=torch.tensor([split_size_sr, split_size_sr]), gs_parameters=gs_parameters,
scale=scale_factor, sample_coords=None,
scale_modify = scale_modify,
default_step_size = default_step_size, mode = mode,
cuda_rendering = cuda_rendering,
if_dmax = if_dmax,
dmax_mode = dmax_mode,
dmax = dmax)
sr_tile_list.append(b_output.unsqueeze(0))
tile_sr_h = sr_tile_list[0].shape[2]
tile_sr_w = sr_tile_list[0].shape[3]
assert tile_sr_w == split_size_sr and tile_sr_h == split_size_sr, \
f'tile_sr_h-{tile_sr_w}, tile_sr_w-{tile_sr_w}, split_size_sr-{split_size_sr} is not the same'
overlap_sr = math.ceil(overlap_size * scale_factor)
sr_pad = torch.zeros(lq.shape[0], lq.shape[1],
(tile_nums_h - 1) * (split_size_sr - overlap_sr) + split_size_sr,
(tile_nums_w - 1) * (split_size_sr - overlap_sr) + split_size_sr,
device=lq.device)
idx = 0
if scale_factor != int(scale_factor):
for h_num in range(tile_nums_h):
for w_num in range(tile_nums_w):
tile_sr_position_start_w = w_num * (split_size_sr - overlap_sr)
tile_sr_position_end_w = tile_sr_position_start_w + split_size_sr
tile_sr_position_start_h = h_num * (split_size_sr - overlap_sr)
tile_sr_position_end_h = tile_sr_position_start_h + split_size_sr
if h_num == 0 and w_num == 0:
sr_pad[:, :, tile_sr_position_start_h:tile_sr_position_end_h,
tile_sr_position_start_w:tile_sr_position_end_w] = sr_tile_list[idx]
elif h_num == 0 and w_num !=0:
if w_num != tile_nums_w - 1:
sr_pad[:, :, tile_sr_position_start_h:tile_sr_position_end_h,
tile_sr_position_start_w+crop_size:tile_sr_position_end_w] = sr_tile_list[idx][:,:,:,crop_size:]
else:
sr_pad[:, :, tile_sr_position_start_h:tile_sr_position_end_h,
tile_sr_position_start_w+crop_size:sr_pad.shape[3]] = sr_tile_list[idx][:,:,:,crop_size:sr_pad.shape[3] - tile_sr_position_start_w]
elif h_num != 0 and w_num ==0:
if h_num != tile_nums_h - 1:
sr_pad[:, :, tile_sr_position_start_h+crop_size:tile_sr_position_end_h,
tile_sr_position_start_w:tile_sr_position_end_w] = sr_tile_list[idx][:,:,crop_size:,:]
else:
sr_pad[:, :, tile_sr_position_start_h+crop_size:sr_pad.shape[2],
tile_sr_position_start_w:tile_sr_position_end_w] = sr_tile_list[idx][:,:,crop_size:sr_pad.shape[2] - tile_sr_position_start_h,:]
else:
if w_num != tile_nums_w - 1 and h_num != tile_nums_h - 1:
sr_pad[:,:,tile_sr_position_start_h+crop_size:tile_sr_position_end_h,
tile_sr_position_start_w+crop_size:tile_sr_position_end_w] = sr_tile_list[idx][:,:,crop_size:,crop_size:]
elif w_num == tile_nums_w - 1 and h_num != tile_nums_h - 1:
sr_pad[:, :, tile_sr_position_start_h:tile_sr_position_end_h,
tile_sr_position_start_w+crop_size:sr_pad.shape[3]] = sr_tile_list[idx][:,:,:,crop_size:sr_pad.shape[3] - tile_sr_position_start_w]
elif w_num != tile_nums_w - 1 and h_num == tile_nums_h - 1:
sr_pad[:, :, tile_sr_position_start_h+crop_size:sr_pad.shape[2],
tile_sr_position_start_w:tile_sr_position_end_w] = sr_tile_list[idx][:,:,crop_size:sr_pad.shape[2] - tile_sr_position_start_h,:]
elif w_num == tile_nums_w - 1 and h_num == tile_nums_h - 1:
sr_pad[:,:,tile_sr_position_start_h+crop_size:sr_pad.shape[2],
tile_sr_position_start_w+crop_size:sr_pad.shape[3]] = sr_tile_list[idx][:,:,crop_size:sr_pad.shape[2] - tile_sr_position_start_h,crop_size:sr_pad.shape[3] - tile_sr_position_start_w]
idx = idx + 1
else:
for h_num in range(tile_nums_h):
for w_num in range(tile_nums_w):
tile_sr_position_start_w = w_num * (split_size_sr - overlap_sr)
tile_sr_position_end_w = tile_sr_position_start_w + split_size_sr
tile_sr_position_start_h = h_num * (split_size_sr - overlap_sr)
tile_sr_position_end_h = tile_sr_position_start_h + split_size_sr
if h_num == 0 and w_num == 0:
sr_pad[:, :, tile_sr_position_start_h:tile_sr_position_end_h,
tile_sr_position_start_w:tile_sr_position_end_w] = sr_tile_list[idx]
elif h_num == 0 and w_num !=0:
sr_pad[:, :, tile_sr_position_start_h:tile_sr_position_end_h,
tile_sr_position_start_w+crop_size:tile_sr_position_end_w] = sr_tile_list[idx][:,:,:,crop_size:]
elif h_num != 0 and w_num ==0:
sr_pad[:, :, tile_sr_position_start_h+crop_size:tile_sr_position_end_h,
tile_sr_position_start_w:tile_sr_position_end_w] = sr_tile_list[idx][:,:,crop_size:,:]
else:
sr_pad[:,:,tile_sr_position_start_h+crop_size:tile_sr_position_end_h,
tile_sr_position_start_w+crop_size:tile_sr_position_end_w] = sr_tile_list[idx][:,:,crop_size:,crop_size:]
idx = idx + 1
print(f"sr_pad shape is {sr_pad.shape}")
# sr_final = sr_pad[:,:, 0:math.ceil(h_lq * scale_factor), 0: math.ceil(w_lq * scale_factor)]
sr_final = sr_pad
return sr_final