|
import glob |
|
import os |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
from PIL import Image |
|
from torchvision import transforms |
|
from tqdm import tqdm |
|
|
|
import model_io |
|
import utils |
|
from adabins import UnetAdaptiveBins |
|
|
|
|
|
def _is_pil_image(img):
    """Report whether *img* is an instance of ``PIL.Image.Image``."""
    pil_image_type = Image.Image
    return isinstance(img, pil_image_type)
|
|
|
|
|
def _is_numpy_image(img): |
|
return isinstance(img, np.ndarray) and (img.ndim in {2, 3}) |
|
|
|
|
|
class ToTensor(object):
    """Convert an image to a CHW torch tensor and normalize it.

    Normalization is ``torchvision.transforms.Normalize`` with per-channel
    mean ``[0.485, 0.456, 0.406]`` and std ``[0.229, 0.224, 0.225]``.
    """

    def __init__(self):
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def __call__(self, image, target_size=(640, 480)):
        """Convert *image* with :meth:`to_tensor`, then normalize it.

        Args:
            image: PIL image, or a 2-/3-D NumPy array laid out HWC.
            target_size: Unused; kept only so existing callers that pass it
                keep working. No resizing is performed.

        Returns:
            The normalized CHW tensor.
        """
        image = self.to_tensor(image)
        image = self.normalize(image)
        return image

    def to_tensor(self, pic):
        """Convert a PIL image or NumPy array to a CHW tensor.

        NumPy input is treated as HWC (a 2-D array gets a singleton channel
        axis) and is returned with its dtype and values unchanged, so the
        caller is responsible for scaling and casting. PIL input is decoded
        according to its mode. 8-bit data is cast to float *without* being
        rescaled to [0, 1].

        Raises:
            TypeError: if *pic* is neither a PIL image nor a 2-/3-D ndarray.
        """
        if not (_is_pil_image(pic) or _is_numpy_image(pic)):
            raise TypeError(
                'pic should be PIL Image or ndarray. Got {}'.format(type(pic)))

        if isinstance(pic, np.ndarray):
            if pic.ndim == 2:
                # 2-D arrays pass _is_numpy_image; give them a channel axis so
                # the HWC -> CHW transpose below is valid.
                pic = pic[:, :, None]
            return torch.from_numpy(pic.transpose((2, 0, 1)))

        # np.array(...) without copy=False: NumPy 2 raises ValueError for
        # copy=False whenever a cast/copy is required. Always copying works on
        # NumPy 1.x and 2.x and yields a writable array for torch.from_numpy.
        if pic.mode == 'I':
            img = torch.from_numpy(np.array(pic, np.int32))
        elif pic.mode == 'I;16':
            img = torch.from_numpy(np.array(pic, np.int16))
        elif pic.mode == 'F':
            img = torch.from_numpy(np.array(pic, np.float32))
        elif pic.mode == '1':
            # Bilevel images arrive as booleans; map True to 255 like other 8-bit data.
            img = 255 * torch.from_numpy(np.array(pic, np.uint8))
        else:
            # Raw interleaved 8-bit bytes; copy() makes the buffer writable.
            img = torch.from_numpy(np.frombuffer(pic.tobytes(), dtype=np.uint8).copy())

        if pic.mode == 'YCbCr':
            nchannel = 3
        elif pic.mode == 'I;16':
            nchannel = 1
        else:
            nchannel = len(pic.mode)
        # PIL reports size as (width, height); lay the data out as HWC.
        img = img.view(pic.size[1], pic.size[0], nchannel)

        # HWC -> CHW
        img = img.permute(2, 0, 1).contiguous()
        if img.dtype == torch.uint8:
            return img.float()
        return img
|
|
|
|
|
class InferenceHelper:
    """Load a pretrained ``UnetAdaptiveBins`` checkpoint and run depth inference.

    Supported datasets (depth range / PNG saving factor / checkpoint file):

    * ``'nyu'``:   [1e-3, 10], factor 1000, ``<models_path>/AdaBins_nyu.pt``
    * ``'kitti'``: [1e-3, 80], factor 256,  ``<models_path>/AdaBins_kitti.pt``
    """

    def __init__(self, models_path, dataset='nyu', device='cuda:0'):
        """Build the model for *dataset* and load its checkpoint from *models_path*.

        Args:
            models_path: Directory containing the ``AdaBins_<dataset>.pt`` checkpoints.
            dataset: ``'nyu'`` or ``'kitti'``.
            device: Torch device the model and inputs are moved to.

        Raises:
            ValueError: if *dataset* is not ``'nyu'`` or ``'kitti'``.
        """
        self.toTensor = ToTensor()
        self.device = device
        if dataset == 'nyu':
            self.min_depth = 1e-3
            self.max_depth = 10
            self.saving_factor = 1000
            checkpoint_name = 'AdaBins_nyu.pt'
        elif dataset == 'kitti':
            self.min_depth = 1e-3
            self.max_depth = 80
            self.saving_factor = 256
            checkpoint_name = 'AdaBins_kitti.pt'
        else:
            raise ValueError("dataset can be either 'nyu' or 'kitti' but got {}".format(dataset))

        model = UnetAdaptiveBins.build(n_bins=256, min_val=self.min_depth, max_val=self.max_depth)
        # Both datasets resolve their checkpoint under models_path (KITTI
        # previously ignored it and always read ./models/AdaBins_kitti.pt).
        pretrained_path = os.path.join(models_path, checkpoint_name)

        model, _, _ = model_io.load_checkpoint(pretrained_path, model)
        model.eval()
        self.model = model.to(self.device)

    @torch.no_grad()
    def predict_pil(self, pil_image, visualized=False):
        """Predict depth for a single PIL image.

        Args:
            pil_image: PIL image (non-RGB modes are converted to RGB), or an
                HWC array with 3 channels and values in [0, 255].
            visualized: If True, also return a colorized PIL rendering of the
                prediction (``utils.colorize`` with the ``'magma'`` colormap).

        Returns:
            ``(bin_centers, pred)`` or ``(bin_centers, pred, viz)``; see
            :meth:`predict` for the first two.
        """
        if isinstance(pil_image, Image.Image) and pil_image.mode != 'RGB':
            # The normalization expects 3 channels; grayscale/palette/RGBA
            # inputs would otherwise fail in ToTensor.
            pil_image = pil_image.convert('RGB')

        img = np.asarray(pil_image) / 255.
        img = self.toTensor(img).unsqueeze(0).float().to(self.device)
        bin_centers, pred = self.predict(img)

        if visualized:
            viz = utils.colorize(torch.from_numpy(pred).unsqueeze(0), vmin=None, vmax=None, cmap='magma')
            viz = Image.fromarray(viz)
            return bin_centers, pred, viz
        return bin_centers, pred

    @torch.no_grad()
    def predict(self, image):
        """Run the model with horizontal-flip test-time augmentation.

        Args:
            image: Batched input tensor on ``self.device`` (callers in this
                class pass shape (1, 3, H, W), normalized by ToTensor).

        Returns:
            ``(centers, final)``:

            * ``centers``: 1-D NumPy array of bin centres from the un-flipped
              pass, restricted to the open interval (min_depth, max_depth).
              ``squeeze()`` plus boolean masking flattens the result, so it
              assumes batch size 1 (TODO confirm for larger batches).
            * ``final``: NumPy depth map, resized to the input's (H, W) and
              clamped to [min_depth, max_depth] (NaN -> min_depth).
        """
        bins, pred = self.model(image)
        pred = np.clip(pred.cpu().numpy(), self.min_depth, self.max_depth)

        # Predict on the mirrored image, then mirror that prediction back and
        # average with the direct prediction. torch.flip stays on-device.
        image = torch.flip(image, dims=[-1])
        pred_lr = self.model(image)[-1]
        pred_lr = np.clip(pred_lr.cpu().numpy()[..., ::-1], self.min_depth, self.max_depth)

        final = 0.5 * (pred + pred_lr)
        # The model output may not match the input resolution; resize to the
        # input's spatial size.
        final = nn.functional.interpolate(torch.Tensor(final), image.shape[-2:],
                                          mode='bilinear', align_corners=True).cpu().numpy()

        final[final < self.min_depth] = self.min_depth
        final[final > self.max_depth] = self.max_depth
        final[np.isinf(final)] = self.max_depth
        final[np.isnan(final)] = self.min_depth

        # Midpoints of consecutive bin edges.
        centers = 0.5 * (bins[:, 1:] + bins[:, :-1])
        centers = centers.cpu().squeeze().numpy()
        centers = centers[centers > self.min_depth]
        centers = centers[centers < self.max_depth]

        return centers, final

    @torch.no_grad()
    def predict_dir(self, test_dir, out_dir):
        """Predict depth for every file in *test_dir* and save 16-bit PNGs.

        Each depth map is multiplied by ``self.saving_factor``, cast to
        uint16 and written to ``<out_dir>/<stem>.png``. Here ``<stem>`` is the
        input file name without its last extension. Sub-directories of
        *test_dir* are skipped. Every remaining file must be readable by PIL.
        """
        os.makedirs(out_dir, exist_ok=True)
        all_files = sorted(p for p in glob.glob(os.path.join(test_dir, "*")) if os.path.isfile(p))
        self.model.eval()
        for f in tqdm(all_files):
            # Context manager closes the file handle that Image.open keeps.
            with Image.open(f) as pil_image:
                if pil_image.mode != 'RGB':
                    pil_image = pil_image.convert('RGB')
                image = np.asarray(pil_image, dtype='float32') / 255.
            image = self.toTensor(image).unsqueeze(0).to(self.device)

            _, final = self.predict(image)

            final = (final * self.saving_factor).astype('uint16')
            # splitext keeps inner dots ("a.1.jpg" -> "a.1"), avoiding output
            # collisions that split('.')[0] caused.
            basename = os.path.splitext(os.path.basename(f))[0]
            save_path = os.path.join(out_dir, basename + ".png")

            Image.fromarray(final.squeeze()).save(save_path)
|
|
|
|
|
if __name__ == '__main__':
    # Demo: load the NYU model, predict depth for one sample image and show it.
    import matplotlib.pyplot as plt
    from time import time

    img = Image.open("test_imgs/classroom__rgb_00283.jpg")
    start = time()
    # models_path is a required argument of InferenceHelper (calling it with
    # no arguments raised TypeError). "./models" is an assumed checkpoint
    # directory; adjust it to wherever AdaBins_nyu.pt lives.
    inferHelper = InferenceHelper(models_path="./models")
    centers, pred = inferHelper.predict_pil(img)
    # Elapsed time includes model construction and checkpoint loading.
    print(f"took :{time() - start}s")
    plt.imshow(pred.squeeze(), cmap='magma_r')
    plt.show()
|
|