import glob
import os

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

import model_io
import utils
from adabins import UnetAdaptiveBins
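
# Inference utilities for AdaBins monocular depth estimation: a ToTensor
# preprocessing transform plus an InferenceHelper that loads a pretrained
# checkpoint and predicts depth from PIL images or whole directories.
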
def _is_pil_image(img):
    return isinstance(img, Image.Image)


def _is_numpy_image(img):
    return isinstance(img, np.ndarray) and (img.ndim in {2, 3})
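

# ToTensor converts a PIL Image or HWC numpy array to a CHW float tensor and
# applies ImageNet mean/std normalization, matching the statistics the
# pretrained encoder expects.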
class ToTensor(object):
    def __init__(self):
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def __call__(self, image, target_size=(640, 480)):
        # image = image.resize(target_size)
        image = self.to_tensor(image)
        image = self.normalize(image)
        return image
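
    # Note: this conversion appears adapted from torchvision's legacy
    # to_tensor: numpy input is assumed HWC; PIL input is decoded by mode.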
    def to_tensor(self, pic):
        if not (_is_pil_image(pic) or _is_numpy_image(pic)):
            raise TypeError(
                'pic should be PIL Image or ndarray. Got {}'.format(type(pic)))

        if isinstance(pic, np.ndarray):
            img = torch.from_numpy(pic.transpose((2, 0, 1)))
            return img

        # handle PIL Image
        if pic.mode == 'I':
            img = torch.from_numpy(np.array(pic, np.int32, copy=False))
        elif pic.mode == 'I;16':
            img = torch.from_numpy(np.array(pic, np.int16, copy=False))
        else:
            img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))

        # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
        if pic.mode == 'YCbCr':
            nchannel = 3
        elif pic.mode == 'I;16':
            nchannel = 1
        else:
            nchannel = len(pic.mode)

        img = img.view(pic.size[1], pic.size[0], nchannel)
        img = img.transpose(0, 1).transpose(0, 2).contiguous()

        if isinstance(img, torch.ByteTensor):
            return img.float()
        else:
            return img
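

# InferenceHelper wraps a pretrained AdaBins model. `models_path` is the
# directory holding the released checkpoints (AdaBins_nyu.pt or
# AdaBins_kitti.pt); the dataset choice fixes the valid depth range and the
# factor used when saving 16-bit depth maps.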
class InferenceHelper:
    def __init__(self, models_path, dataset='nyu', device='cuda:0'):
        self.toTensor = ToTensor()
        self.device = device
        if dataset == 'nyu':
            self.min_depth = 1e-3
            self.max_depth = 10
            self.saving_factor = 1000  # used to save in 16 bit
            model = UnetAdaptiveBins.build(n_bins=256, min_val=self.min_depth, max_val=self.max_depth)
            pretrained_path = os.path.join(models_path, 'AdaBins_nyu.pt')
        elif dataset == 'kitti':
            self.min_depth = 1e-3
            self.max_depth = 80
            self.saving_factor = 256
            model = UnetAdaptiveBins.build(n_bins=256, min_val=self.min_depth, max_val=self.max_depth)
            # resolve the checkpoint relative to models_path, consistent with the nyu branch
            pretrained_path = os.path.join(models_path, 'AdaBins_kitti.pt')
        else:
            raise ValueError("dataset can be either 'nyu' or 'kitti' but got {}".format(dataset))

        model, _, _ = model_io.load_checkpoint(pretrained_path, model)
        model.eval()
        self.model = model.to(self.device)
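
    # predict_pil normalizes a PIL image to [0, 1], predicts depth, and can
    # also return a magma-colorized visualization of the depth map.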
    @torch.no_grad()
    def predict_pil(self, pil_image, visualized=False):
        # pil_image = pil_image.resize((640, 480))
        img = np.asarray(pil_image) / 255.

        img = self.toTensor(img).unsqueeze(0).float().to(self.device)
        bin_centers, pred = self.predict(img)

        if visualized:
            viz = utils.colorize(torch.from_numpy(pred).unsqueeze(0), vmin=None, vmax=None, cmap='magma')
            # pred = np.asarray(pred*1000, dtype='uint16')
            viz = Image.fromarray(viz)
            return bin_centers, pred, viz

        return bin_centers, pred
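
    # predict applies a simple horizontal-flip test-time augmentation: the
    # model is run on the image and its mirror, the mirrored prediction is
    # flipped back, and the two depth maps are averaged, upsampled to the
    # input resolution, and clipped to [min_depth, max_depth]. Bin centers
    # are derived from the predicted bin edges and filtered to the valid range.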
    @torch.no_grad()
    def predict(self, image):
        bins, pred = self.model(image)
        pred = np.clip(pred.cpu().numpy(), self.min_depth, self.max_depth)

        # Flip
        image = torch.Tensor(np.array(image.cpu().numpy())[..., ::-1].copy()).to(self.device)
        pred_lr = self.model(image)[-1]
        pred_lr = np.clip(pred_lr.cpu().numpy()[..., ::-1], self.min_depth, self.max_depth)

        # Take average of original and mirror
        final = 0.5 * (pred + pred_lr)
        final = nn.functional.interpolate(torch.Tensor(final), image.shape[-2:],
                                          mode='bilinear', align_corners=True).cpu().numpy()

        final[final < self.min_depth] = self.min_depth
        final[final > self.max_depth] = self.max_depth
        final[np.isinf(final)] = self.max_depth
        final[np.isnan(final)] = self.min_depth

        centers = 0.5 * (bins[:, 1:] + bins[:, :-1])
        centers = centers.cpu().squeeze().numpy()
        centers = centers[centers > self.min_depth]
        centers = centers[centers < self.max_depth]

        return centers, final
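
    # predict_dir runs inference on every file in `test_dir` and saves each
    # depth map to `out_dir` as a 16-bit PNG scaled by `saving_factor`
    # (1000 for NYU, 256 for KITTI).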
    @torch.no_grad()
    def predict_dir(self, test_dir, out_dir):
        os.makedirs(out_dir, exist_ok=True)
        transform = ToTensor()
        all_files = glob.glob(os.path.join(test_dir, "*"))
        self.model.eval()
        for f in tqdm(all_files):
            image = np.asarray(Image.open(f), dtype='float32') / 255.
            image = transform(image).unsqueeze(0).to(self.device)
            centers, final = self.predict(image)
            # final = final.squeeze().cpu().numpy()

            final = (final * self.saving_factor).astype('uint16')
            # use splitext so filenames containing dots keep their full stem
            basename = os.path.splitext(os.path.basename(f))[0]
            save_path = os.path.join(out_dir, basename + ".png")

            Image.fromarray(final.squeeze()).save(save_path)
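

# The same helper works for the KITTI checkpoint, e.g.
# InferenceHelper(models_path="./models", dataset='kitti'); only the depth
# range (up to 80 m) and the 16-bit saving factor change.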
if __name__ == '__main__':
    import matplotlib.pyplot as plt
    from time import time

    img = Image.open("test_imgs/classroom__rgb_00283.jpg")
    start = time()
    # models_path is required; "./models" is assumed to contain AdaBins_nyu.pt
    inferHelper = InferenceHelper(models_path="./models")
    centers, pred = inferHelper.predict_pil(img)
    print(f"took: {time() - start}s")
    plt.imshow(pred.squeeze(), cmap='magma_r')
    plt.show()
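
    # Batch inference over a folder works the same way (output dir illustrative):
    # inferHelper.predict_dir("test_imgs/", "predictions/")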