import glob |
import os |
import numpy as np |
import torch |
import torch.nn as nn |
from PIL import Image |
from torchvision import transforms |
from tqdm import tqdm |
import model_io |
import utils |
from adabins import UnetAdaptiveBins |
def _is_pil_image(img):
    """Return True when ``img`` is an instance of PIL's ``Image.Image``."""
    pil_image_type = Image.Image
    return isinstance(img, pil_image_type)
def _is_numpy_image(img): |
return isinstance(img, np.ndarray) and (img.ndim in {2, 3}) |
class ToTensor(object):
    """Convert an image to a channel-first tensor and normalize it per channel.

    Calling an instance runs :meth:`to_tensor` followed by a three-channel
    ``transforms.Normalize``; the normalization therefore expects a
    3-channel floating-point image.
    """

    def __init__(self):
        # Per-channel mean/std used by __call__ (these match the commonly
        # used ImageNet statistics).
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def __call__(self, image, target_size=(640, 480)):
        """Convert ``image`` to a tensor and normalize it.

        Args:
            image: PIL image or ndarray accepted by :meth:`to_tensor`.
            target_size: Unused. Kept only so callers that pass it keep working.

        Returns:
            The normalized channel-first tensor.
        """
        image = self.to_tensor(image)
        return self.normalize(image)

    def to_tensor(self, pic):
        """Convert a PIL image or ndarray to a ``(C, H, W)`` tensor.

        Args:
            pic: A PIL image, or an ndarray laid out as ``(H, W, C)`` or
                ``(H, W)``. A 2-D array is treated as a single-channel image.

        Returns:
            A channel-first tensor. ndarray inputs keep their dtype; byte
            based PIL images are returned as float tensors (values are not
            rescaled).

        Raises:
            TypeError: If ``pic`` is neither a PIL image nor a 2-D/3-D ndarray.
        """
        if isinstance(pic, np.ndarray):
            if pic.ndim == 2:
                # Previously a 2-D array passed validation but crashed in the
                # transpose below; give it an explicit channel axis instead.
                pic = pic[:, :, None]
            elif pic.ndim != 3:
                raise TypeError(
                    'pic should be a 2 or 3 dimensional ndarray. Got {} dimensions.'.format(pic.ndim))
            return torch.from_numpy(pic.transpose((2, 0, 1)))

        if not _is_pil_image(pic):
            raise TypeError(
                'pic should be PIL Image or ndarray. Got {}'.format(type(pic)))

        # Handle PIL images. ``copy=False`` is deliberately not passed to
        # np.array: under NumPy >= 2 it raises when a copy is unavoidable,
        # which is the case when converting from a PIL image.
        if pic.mode == 'I':
            img = torch.from_numpy(np.array(pic, np.int32))
        elif pic.mode == 'I;16':
            img = torch.from_numpy(np.array(pic, np.int16))
        elif pic.mode == 'F':
            # 32-bit float pixels: reading them as raw bytes would not match
            # the (H, W, 1) view below.
            img = torch.from_numpy(np.array(pic, np.float32))
        else:
            # np.frombuffer yields a read-only view over the bytes object, so
            # copy it to obtain a writable array for torch.from_numpy.
            raw = np.frombuffer(pic.tobytes(), dtype=np.uint8).copy()
            img = torch.from_numpy(raw)

        # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
        if pic.mode == 'YCbCr':
            nchannel = 3
        elif pic.mode == 'I;16':
            nchannel = 1
        else:
            nchannel = len(pic.mode)

        # PIL reports size as (width, height); build H x W x C, then move the
        # channel axis to the front.
        img = img.view(pic.size[1], pic.size[0], nchannel)
        img = img.permute(2, 0, 1).contiguous()

        if img.dtype == torch.uint8:
            return img.float()
        return img
class InferenceHelper:
    """Load a pretrained AdaBins model and run depth prediction with it.

    Attributes set in ``__init__``:
        toTensor: ``ToTensor`` instance used to preprocess input images.
        device: Device string the model and inputs are moved to.
        min_depth / max_depth: Bounds the predictions are clamped to.
        saving_factor: Multiplier applied to depth values before they are
            cast to ``uint16`` in :meth:`predict_dir`.
        model: The loaded model, in eval mode, on ``device``.
    """

    def __init__(self, models_path, dataset='nyu', device='cuda:0'):
        """Build the model for ``dataset`` and load its checkpoint.

        Args:
            models_path: Directory containing ``AdaBins_nyu.pt`` /
                ``AdaBins_kitti.pt``.
            dataset: Either ``'nyu'`` or ``'kitti'``.
            device: Device the model is moved to.

        Raises:
            ValueError: If ``dataset`` is not ``'nyu'`` or ``'kitti'``.
        """
        self.toTensor = ToTensor()
        self.device = device
        if dataset == 'nyu':
            self.min_depth = 1e-3
            self.max_depth = 10
            self.saving_factor = 1000
            checkpoint_name = 'AdaBins_nyu.pt'
        elif dataset == 'kitti':
            self.min_depth = 1e-3
            self.max_depth = 80
            self.saving_factor = 256
            checkpoint_name = 'AdaBins_kitti.pt'
        else:
            raise ValueError("dataset can be either 'nyu' or 'kitti' but got {}".format(dataset))

        model = UnetAdaptiveBins.build(n_bins=256, min_val=self.min_depth, max_val=self.max_depth)
        # Both datasets resolve their checkpoint relative to ``models_path``;
        # the kitti branch used to ignore it and hard-code "./models".
        pretrained_path = os.path.join(models_path, checkpoint_name)
        model, _, _ = model_io.load_checkpoint(pretrained_path, model)
        model.eval()
        self.model = model.to(self.device)

    @torch.no_grad()
    def predict_pil(self, pil_image, visualized=False):
        """Predict depth for a single image.

        Args:
            pil_image: Input image. PIL images that are not in RGB mode are
                converted to RGB first, since the preprocessing normalizes
                three channels.
            visualized: When true, also return a colorized PIL image.

        Returns:
            ``(bin_centers, pred)`` or, when ``visualized`` is true,
            ``(bin_centers, pred, viz)``.
        """
        if isinstance(pil_image, Image.Image) and pil_image.mode != 'RGB':
            pil_image = pil_image.convert('RGB')
        img = np.asarray(pil_image) / 255.
        img = self.toTensor(img).unsqueeze(0).float().to(self.device)
        bin_centers, pred = self.predict(img)

        if visualized:
            viz = utils.colorize(torch.from_numpy(pred).unsqueeze(0), vmin=None, vmax=None, cmap='magma')
            viz = Image.fromarray(viz)
            return bin_centers, pred, viz
        return bin_centers, pred

    @torch.no_grad()
    def predict(self, image):
        """Run the model on ``image`` and on its mirror image, then average.

        Args:
            image: Input tensor on ``self.device``.
                NOTE(review): presumably shaped (N, C, H, W) - confirm.

        Returns:
            ``(centers, final)``: the bin centers lying strictly between
            ``min_depth`` and ``max_depth`` as a NumPy array, and the averaged
            prediction resized to the input's last two dimensions and clamped
            to ``[min_depth, max_depth]`` (NaNs become ``min_depth``).
        """
        bins, pred = self.model(image)
        pred = np.clip(pred.cpu().numpy(), self.min_depth, self.max_depth)

        # Second pass on the input mirrored along its last axis. torch.flip
        # keeps the data on its device instead of round-tripping via NumPy.
        flipped = torch.flip(image, dims=[-1])
        pred_lr = self.model(flipped)[-1]
        # Mirror the prediction back so it lines up with the first pass.
        pred_lr = np.clip(pred_lr.cpu().numpy()[..., ::-1], self.min_depth, self.max_depth)

        # Average the two predictions, then resize to the input resolution.
        final = 0.5 * (pred + pred_lr)
        final = nn.functional.interpolate(torch.Tensor(final), image.shape[-2:],
                                          mode='bilinear', align_corners=True).cpu().numpy()

        # Clamping also maps +/-inf onto the bounds; NaN passes through
        # np.clip unchanged, so it is handled separately.
        final = np.clip(final, self.min_depth, self.max_depth)
        final[np.isnan(final)] = self.min_depth

        centers = 0.5 * (bins[:, 1:] + bins[:, :-1])
        centers = centers.cpu().squeeze().numpy()
        centers = centers[centers > self.min_depth]
        centers = centers[centers < self.max_depth]

        return centers, final

    @torch.no_grad()
    def predict_dir(self, test_dir, out_dir):
        """Predict depth for every file in ``test_dir`` and save 16-bit PNGs.

        Args:
            test_dir: Directory whose files are read as images.
            out_dir: Output directory (created if missing). Each result is
                saved as ``<name up to the first dot>.png``.
        """
        os.makedirs(out_dir, exist_ok=True)
        # Skip sub-directories matched by the glob; they cannot be opened as
        # images. Sorting makes the processing order deterministic.
        all_files = sorted(
            path for path in glob.glob(os.path.join(test_dir, "*"))
            if os.path.isfile(path)
        )
        self.model.eval()
        for f in tqdm(all_files):
            # Close the file handle as soon as the pixels are materialized.
            with Image.open(f) as pil_image:
                image = np.asarray(pil_image.convert('RGB'), dtype='float32') / 255.
            image = self.toTensor(image).unsqueeze(0).to(self.device)

            _, final = self.predict(image)
            final = (final * self.saving_factor).astype('uint16')
            # NOTE: the name is cut at the FIRST dot, so "a.b.jpg" is saved as "a.png".
            basename = os.path.basename(f).split('.')[0]
            save_path = os.path.join(out_dir, basename + ".png")

            Image.fromarray(final.squeeze()).save(save_path)
if __name__ == '__main__':
    # Demo: time a single prediction on a sample image and display it.
    import matplotlib.pyplot as plt
    from time import time

    img = Image.open("test_imgs/classroom__rgb_00283.jpg")
    start = time()
    # InferenceHelper requires the checkpoint directory; calling it with no
    # arguments raised TypeError. "./models" is the directory the kitti
    # checkpoint path in this file points at.
    inferHelper = InferenceHelper("./models")
    centers, pred = inferHelper.predict_pil(img)
    # The measured time includes model construction and checkpoint loading.
    print(f"took :{time() - start}s")
    plt.imshow(pred.squeeze(), cmap='magma_r')
    plt.show()