|
|
|
from PIL import Image |
|
import numpy as np |
|
from skimage import color |
|
import torch |
|
import torch.nn.functional as F |
|
from IPython import embed |
|
|
|
def load_img(img_path):
    """Load an image file into a NumPy array that has a channel axis.

    Args:
        img_path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        The decoded pixels as a NumPy array. A 2-D (single-channel) result
        is replicated three times along a new last axis, giving ``(H, W, 3)``;
        arrays that already have a channel axis are returned unchanged.
    """
    # Image.open is lazy and keeps the underlying file open. The context
    # manager releases the handle deterministically once np.asarray has
    # pulled the pixel data into the array.
    with Image.open(img_path) as img:
        out_np = np.asarray(img)

    if out_np.ndim == 2:
        # Grayscale: copy the single plane into three identical channels.
        out_np = np.tile(out_np[:, :, None], 3)
    return out_np
|
|
|
def resize_img(img, HW=(256, 256), resample=3):
    """Resize an image array to ``HW`` using PIL.

    Args:
        img: Array accepted by ``PIL.Image.fromarray``.
        HW: Target size, indexed as ``(height, width)``.
        resample: PIL resampling filter passed to ``Image.resize``
            (3 is PIL's bicubic filter).

    Returns:
        The resized image as a NumPy array.
    """
    target_h = HW[0]
    target_w = HW[1]
    pil_img = Image.fromarray(img)
    # PIL expects the size as (width, height), the reverse of HW.
    resized = pil_img.resize((target_w, target_h), resample=resample)
    return np.asarray(resized)
|
|
|
def preprocess_img(img_rgb_orig, HW=(256, 256), resample=3):
    """Build L-channel tensors for an RGB image at original and resized size.

    Args:
        img_rgb_orig: RGB image array at its original resolution.
        HW: ``(height, width)`` forwarded to ``resize_img``.
        resample: Resampling filter forwarded to ``resize_img``.

    Returns:
        A tuple ``(tens_orig_l, tens_rs_l)``: the first channel of the Lab
        conversion of the original and of the resized image, each wrapped in
        a ``torch.Tensor`` with two leading singleton axes, i.e. shape
        ``(1, 1, H, W)``.
    """
    def _l_channel_tensor(img_rgb):
        # RGB -> Lab, keep channel 0 (L), then add batch and channel axes.
        lab = color.rgb2lab(img_rgb)
        return torch.Tensor(lab[:, :, 0])[None, None, :, :]

    # Resize in RGB space first; each version is converted to Lab separately.
    resized_rgb = resize_img(img_rgb_orig, HW=HW, resample=resample)

    full_res_l = _l_channel_tensor(img_rgb_orig)
    resized_l = _l_channel_tensor(resized_rgb)
    return (full_res_l, resized_l)
|
|
|
def postprocess_tens(tens_orig_l, out_ab, mode='bilinear'):
    """Combine an L tensor with predicted ab channels and convert to RGB.

    Args:
        tens_orig_l: 4-D tensor holding the L channel; its last two axes
            define the output spatial size (expected ``(1, 1, H_orig, W_orig)``).
        out_ab: 4-D tensor holding the ab channels (expected ``(1, 2, H, W)``);
            resized to match ``tens_orig_l`` when the spatial sizes differ.
        mode: Interpolation mode passed to ``torch.nn.functional.interpolate``
            when resizing ``out_ab``. Defaults to ``'bilinear'``.

    Returns:
        The RGB image produced by ``skimage.color.lab2rgb`` for the first
        batch element, as an ``(H_orig, W_orig, 3)`` NumPy array.
    """
    HW_orig = tens_orig_l.shape[2:]
    HW = out_ab.shape[2:]

    # Only resample the ab prediction when its spatial size differs from L.
    if HW_orig[0] != HW[0] or HW_orig[1] != HW[1]:
        # Honor the caller's `mode`; it was previously ignored in favor of a
        # hard-coded 'bilinear'.
        out_ab_orig = F.interpolate(out_ab, size=HW_orig, mode=mode)
    else:
        out_ab_orig = out_ab

    # Stack L and ab along the channel axis to form a Lab tensor.
    out_lab_orig = torch.cat((tens_orig_l, out_ab_orig), dim=1)

    # Take the first batch element and move channels last (C, H, W -> H, W, C),
    # the layout lab2rgb expects.
    lab_np = out_lab_orig.detach().cpu().numpy()[0, ...].transpose((1, 2, 0))
    return color.lab2rgb(lab_np)
|
|