Traly's picture
init
193c713
import math
import numpy as np
import torch
from matplotlib import pyplot as plt
from torchvision.utils import make_grid
def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1), rgb2bgr=True):
    """Convert torch Tensors into image numpy arrays.

    Values are clamped to ``min_max`` and then linearly rescaled to [0, 1].
    The input tensor is never modified; clamping is done out of place.

    Up to two leading singleton dimensions are dropped first, but a tensor is
    never squeezed below 2D. Behavior then depends on the resulting shape:

    1. 4D mini-batch Tensor of shape (N x 3/1 x H x W):
       Use ``make_grid`` to stitch images in the batch dimension, then
       convert to an (H x W x C) numpy array.
    2. 3D Tensor of shape (3 x H x W): convert to an (H x W x C) array.
       (A (1 x H x W) tensor is squeezed to 2D first.)
    3. 2D Tensor of shape (H x W): convert directly.

    Input channels are expected in RGB order. With ``rgb2bgr=True`` (the
    default) the output follows the cv2 convention, i.e. (H x W x C) with BGR
    order; with ``rgb2bgr=False`` the RGB order is kept (e.g. for matplotlib).

    Args:
        tensor (Tensor | list[Tensor]): Input tensor(s).
        out_type (numpy type): Output dtype. If ``np.uint8``, outputs are
            rounded into the range [0, 255]; otherwise they are cast to
            ``out_type`` with range [0, 1]. Default: ``np.uint8``.
        min_max (tuple): (min, max) values used for clamping.
        rgb2bgr (bool): Whether to reorder color channels from RGB to BGR.
            Default: True.

    Returns:
        ndarray | list[ndarray]: A 3D array (H x W x C) or 2D array (H x W).
        A single array is returned when exactly one tensor was converted,
        otherwise a list of arrays.

    Raises:
        TypeError: If ``tensor`` is neither a tensor nor a list of tensors.
        ValueError: If a tensor is not 2D, 3D or 4D after squeezing.
    """
    if not (torch.is_tensor(tensor) or
            (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError(
            f'tensor or list of tensors expected, got {type(tensor)}')
    if torch.is_tensor(tensor):
        tensor = [tensor]
    low, high = min_max
    result = []
    for _tensor in tensor:
        # Drop up to two leading singleton dims, but never go below 2D so a
        # single-row image such as (1, W) is not collapsed to 1D:
        #   (1, 1, h, w)     -> (h, w)
        #   (1, 3, h, w)     -> (3, h, w)
        #   (1, h, w)        -> (h, w)
        #   (n>1, 3/1, h, w) -> unchanged
        for _ in range(2):
            if _tensor.dim() > 2 and _tensor.size(0) == 1:
                _tensor = _tensor.squeeze(0)
        # Out-of-place clamp: detach()/float()/cpu() can all return the
        # caller's storage (no copy for a float32 CPU tensor), so an in-place
        # clamp_ here would silently modify the input tensor.
        _tensor = _tensor.detach().float().cpu().clamp(low, high)
        _tensor = (_tensor - low) / (high - low)
        n_dim = _tensor.dim()
        if n_dim == 4:
            img_np = make_grid(
                _tensor, nrow=int(math.sqrt(_tensor.size(0))),
                normalize=False).numpy()
        elif n_dim in (2, 3):
            img_np = _tensor.numpy()
        else:
            raise ValueError('Only support 4D, 3D or 2D tensor. '
                             f'But received with dimension: {n_dim}')
        if img_np.ndim == 3:
            # CHW -> HWC, optionally reordering RGB -> BGR on the way.
            if rgb2bgr:
                img_np = img_np[[2, 1, 0], :, :]
            img_np = np.transpose(img_np, (1, 2, 0))
        if out_type == np.uint8:
            # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
            img_np = (img_np * 255.0).round()
        img_np = img_np.astype(out_type)
        result.append(img_np)
    return result[0] if len(result) == 1 else result
def plt_tensor_img(tensor, save_path=None):
    """Display a single image tensor with matplotlib, optionally saving it.

    Args:
        tensor (Tensor): Image tensor accepted by ``tensor2img`` (RGB channel
            order, values expected in [0, 1] and clamped otherwise).
        save_path (str | None): If truthy, the figure is written to this path
            before it is shown.
    """
    img = tensor2img(tensor)
    # tensor2img returns cv2-style BGR channels; matplotlib expects RGB.
    if img.ndim == 3:
        img = img[..., ::-1]
    # 2D results are single-channel images: render them in gray rather than
    # matplotlib's default false-color map.
    plt.imshow(img, cmap='gray' if img.ndim == 2 else None)
    # Save before show(): with an interactive backend the figure is released
    # once its window is closed, so a later savefig() writes an empty canvas.
    if save_path:
        plt.savefig(save_path)
    plt.show()
def plt_tensor_img_one(tensor, t_dim=1):
    """Plot every slice of ``tensor`` along ``t_dim`` in one square grid.

    Args:
        tensor (Tensor | list[Tensor]): Tensor to slice; a list is first
            concatenated along ``t_dim``.
        t_dim (int): Dimension to iterate over (any valid dim, including
            negative ones). Each index along it becomes one subplot rendered
            through ``tensor2img``. Default: 1.
    """
    if isinstance(tensor, list):
        tensor = torch.cat(tensor, dim=t_dim)
    nums = tensor.shape[t_dim]
    # Side of the smallest square grid that holds all slices (mash ** 2 >= nums).
    mash = math.ceil(math.sqrt(nums))
    plt.figure(dpi=300)
    for i in range(nums):
        plt.subplot(mash, mash, i + 1)
        # select(t_dim, i) equals tensor[:, i, ...] for t_dim=1 and
        # tensor[i, ...] for t_dim=0, and also works for any other dim.
        img = tensor2img(tensor.select(t_dim, i))
        if img.ndim == 3:
            img = img[..., ::-1]  # BGR (cv2 convention) -> RGB (matplotlib)
        plt.imshow(img, cmap='gray' if img.ndim == 2 else None)
        plt.xticks([])
        plt.yticks([])
    # Lay out once, after all axes exist. tight_layout() recomputes the
    # subplot spacing, so the zero gaps are applied after it.
    plt.tight_layout()
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.show()
def plt_img(img, save_path=None):
    """Display an image with matplotlib, optionally saving it.

    Args:
        img: Any image data accepted by ``plt.imshow``.
        save_path (str | None): If truthy, the figure is written to this path
            before it is shown.
    """
    plt.imshow(img)
    # Save before show(): with an interactive backend the figure is released
    # once its window is closed, so a later savefig() writes an empty canvas.
    if save_path:
        plt.savefig(save_path)
    plt.show()