|
""" |
|
Different utilities such as orthogonalization of weights, initialization of |
|
loggers, etc |
|
|
|
Copyright (C) 2018, Matias Tassano <[email protected]> |
|
|
|
This program is free software: you can use, modify and/or |
|
redistribute it under the terms of the GNU General Public |
|
License as published by the Free Software Foundation, either |
|
version 3 of the License, or (at your option) any later |
|
version. You should have received a copy of this license along with
|
this program. If not, see <http://www.gnu.org/licenses/>. |
|
""" |
|
import numpy as np |
|
import cv2 |
|
|
|
|
|
def variable_to_cv2_image(varim):
    r"""Turns a torch.autograd.Variable into an image usable by OpenCV

    Only the first element along the first dimension is converted. Its values
    are multiplied by 255, clipped to [0, 255] and cast to uint8 (the cast
    truncates the fractional part).

    Args:
        varim: a torch.autograd.Variable whose second dimension holds the
            color channels (1 or 3); presumably laid out as NxCxHxW
    Returns:
        a uint8 numpy array; for three channels it is channels-last and
        converted from RGB to BGR order
    Raises:
        Exception: if the number of channels is neither 1 nor 3
    """
    channel_count = varim.size()[1]
    if channel_count not in (1, 3):
        raise Exception('Number of color channels not supported')

    batch = varim.data.cpu().numpy()
    if channel_count == 1:
        scaled = batch[0, 0, :] * 255.
    else:
        # channels-first -> channels-last, then RGB -> BGR as OpenCV expects
        channels_last = batch[0].transpose(1, 2, 0)
        scaled = cv2.cvtColor(channels_last, cv2.COLOR_RGB2BGR) * 255.
    return scaled.clip(0, 255).astype(np.uint8)
|
|
|
|
|
def normalize(data):
    r"""Divides data by 255 and returns the result as float32

    Args:
        data: the values to rescale, e.g. a numpy array of 8-bit intensities
    """
    scaled = data / 255.
    return np.float32(scaled)
|
|
|
def remove_dataparallel_wrapper(state_dict):
    r"""Converts a DataParallel model to a normal one by removing the "module."
    wrapper in the module dictionary

    Only keys that actually start with "module." are renamed; any other key is
    copied unchanged, so a dictionary that is not wrapped comes back intact.

    Args:
        state_dict: a torch.nn.DataParallel state dictionary
    Returns:
        a new OrderedDict holding the same values, in the same order, under
        the unwrapped key names
    """
    from collections import OrderedDict

    prefix = 'module.'
    new_state_dict = OrderedDict()
    for k, vl in state_dict.items():
        # Strip the prefix only when present: slicing blindly would corrupt
        # the keys of a dictionary saved without the DataParallel wrapper
        name = k[len(prefix):] if k.startswith(prefix) else k
        new_state_dict[name] = vl

    return new_state_dict
|
|
|
def is_rgb(im_path):
    r""" Returns True if the image in im_path is an RGB image

    An image counts as RGB when it is a 3D array with at least three channels
    along its last axis and the first three channels are not all numerically
    close to each other. 3D images with fewer than three channels (e.g.
    grayscale plus alpha) are reported as not RGB.

    The result and the image shape are also printed to stdout.

    Args:
        im_path: path of the image file, read with skimage.io.imread
    Returns:
        True if the image carries distinct color channels, False otherwise
    """
    from skimage.io import imread

    im = imread(im_path)
    rgb = False
    # Require three channels before indexing them: a 3D array with fewer
    # channels would otherwise raise an IndexError on im[..., 2]
    if len(im.shape) == 3 and im.shape[2] >= 3:
        same_channels = (np.allclose(im[..., 0], im[..., 1])
                         and np.allclose(im[..., 2], im[..., 1]))
        rgb = not same_channels
    print("rgb: {}".format(rgb))
    print("im shape: {}".format(im.shape))
    return rgb
|
|