File size: 1,997 Bytes
3fce28b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from typing import Optional, Union
from tqdm.auto import trange
from PIL import ImageOps
from PIL import Image
from torch import nn
import numpy as np
import torch
import cv2


class MidasDepth(nn.Module):
    """Depth estimator wrapping a MiDaS model loaded through ``torch.hub``.

    Args:
        model_type: MiDaS hub entry point to load (e.g. ``"DPT_Large"``).
        device: Device the model runs on, as a ``torch.device`` or anything
            ``torch.device`` accepts (e.g. ``"cuda"``, ``"cpu"``). An MPS
            device is replaced by the CPU.
        is_inpainting: Accepted for interface compatibility; currently unused.
    """

    def __init__(self, model_type="DPT_Large",
                 device=torch.device(
                     "cuda" if torch.cuda.is_available() else "cpu"),
                 is_inpainting=False):
        super().__init__()
        # Normalise so plain strings like "cuda:0" work as well as torch.device.
        self.device = torch.device(device)
        if self.device.type == "mps":
            # NOTE(review): MPS is forced back to CPU; presumably some MiDaS
            # ops are unsupported there — confirm before removing.
            self.device = torch.device("cpu")
        self.model = torch.hub.load(
            "intel-isl/MiDaS", model_type).to(self.device).eval().requires_grad_(False)
        # NOTE(review): dpt_transform is used for every model_type; the MiDaS
        # hub docs pair non-DPT models (e.g. MiDaS_small) with small_transform.
        self.transform = torch.hub.load(
            "intel-isl/MiDaS", "transforms").dpt_transform

    @torch.no_grad()
    def forward(self, image):
        """Run the model on one image and resize its output to the input size.

        Args:
            image: Channels-last image (H x W x C), optionally with leading
                singleton batch axes, given as a tensor, numpy array or
                anything ``np.asarray`` accepts. It is fed to
                ``self.transform`` (assumed to expect an H x W x 3 array —
                TODO confirm for non-DPT model types).

        Returns:
            Tensor of shape (B, H, W) on ``self.device`` with the raw model
            output resized bicubically to the input's height and width
            (presumably relative inverse depth — verify against MiDaS docs).
        """
        if torch.is_tensor(image):
            image = image.cpu().detach()
        if not isinstance(image, np.ndarray):
            image = np.asarray(image)
        # Strip only leading singleton batch axes; unlike a blanket squeeze()
        # this keeps a spatial axis of size 1 (1-pixel-high/wide images).
        while image.ndim > 3 and image.shape[0] == 1:
            image = image[0]
        batch = self.transform(image).to(self.device)
        prediction = self.model(batch)
        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=image.shape[-3:-1],  # (H, W) of the channels-last input
            mode="bicubic",
            align_corners=False,
        )[:, 0]
        return prediction

    @torch.no_grad()
    def get_depth(self, img):
        """Return a depth-like map for ``img`` as a numpy array.

        The raw model output is min-max normalised to ``[3, 10]`` and then
        mapped through ``30 / d``, so the largest raw value becomes 3 and the
        smallest becomes 10. A constant prediction (zero range) maps to 10
        everywhere instead of producing NaN.

        Args:
            img: H x W x 3 image — a PIL image, numpy array or CPU tensor,
                with values presumably in 0-255. Objects exposing a non-RGB
                ``mode`` (PIL images in e.g. "RGBA", "L", "P") are converted
                with ``.convert("RGB")`` first.

        Returns:
            Float numpy array of shape (H, W).
        """
        # Duck-typed PIL check: a non-RGB mode would hand the transform the
        # wrong number of channels.
        mode = getattr(img, "mode", None)
        if mode is not None and mode != "RGB":
            img = img.convert("RGB")
        # Stay on the CPU in the 0-255 range: forward() moves the input to the
        # CPU before transforming anyway, so a device round trip plus a
        # /255-then-*255 rescale would be wasted work.
        im = torch.from_numpy(np.array(img, dtype=np.float32))
        og_depth = self(im.unsqueeze(0))[0]
        lo, hi = og_depth.min(), og_depth.max()
        span = hi - lo
        if span == 0:
            # Flat prediction: avoid 0/0 = NaN by treating every pixel as the minimum.
            normalized = torch.zeros_like(og_depth)
        else:
            normalized = (og_depth - lo) / span
        d = normalized * (10 - 3) + 3
        d = 30 / d
        return d.detach().cpu().numpy()


if __name__ == "__main__":
    import matplotlib.pyplot as plt

    # Quick visual check: estimate depth for a local sample image and show it.
    estimator = MidasDepth()
    depth_map = estimator.get_depth(Image.open("horse.jpg"))
    plt.imshow(depth_map)
    plt.show()