Update app.py
app.py CHANGED
@@ -10,13 +10,102 @@ from typing import Dict
 import functools
 import inspect
 from types import SimpleNamespace
-import torch
 from torch.utils.data import Dataset
 from torchvision import transforms
 import rasterio
 from pathlib import Path
 from torchvision.transforms import ToPILImage
-
+from base64 import b64encode
+import gc
+from datasets import load_dataset
+import torchvision
+import torch.nn.functional as F
+from IPython.display import HTML
+from matplotlib import pyplot as plt
+from pathlib import Path
+from torch import autocast
+from torchvision import transforms as tfms
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer, logging
+import os
+import csv
+from torchvision.utils import save_image
+import torch
+import cv2
+from PIL import Image
+import os
+from django.conf import settings
+import torch.nn.functional as F
+import os
+import torch
+from transformers import AutoImageProcessor, SwinModel
+from diffusers import UNet2DConditionModel
+
+def load_models():
+    torch_device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+    image_processor_model_path = os.path.join(settings.BASE_DIR, 'depthAPI', 'models', 'image_processor')
+    swin_transformer_model_path = os.path.join(settings.BASE_DIR, 'depthAPI', 'models', 'swin_transformer')
+    vae_model_path = os.path.join(settings.BASE_DIR, 'depthAPI', 'models', 'vae', 'MonoChannelVAE.pth')
+    unet_model_path = os.path.join(settings.BASE_DIR, 'depthAPI', 'models', 'unet')
+
+    image_processor = AutoImageProcessor.from_pretrained(image_processor_model_path)
+    swin_transformer = SwinModel.from_pretrained(swin_transformer_model_path)
+
+    vae = Autoencoder()
+    vae.load_state_dict(torch.load(vae_model_path, map_location=torch.device('cpu')))
+    unet = UNet2DConditionModel.from_pretrained(unet_model_path)
+    scheduler = DDIMScheduler(beta_start=0.0001, beta_end=0.02, beta_schedule='linear',
+                              num_train_timesteps=1000)
+
+    vae = vae.to(torch_device)
+    swin_transformer = swin_transformer.to(torch_device)
+    unet = unet.to(torch_device)
+
+    return image_processor, swin_transformer, vae, unet, scheduler
+
+def tensor_to_latent(input_im, vae):
+    with torch.no_grad():
+        latent = vae.encoder(input_im)
+    return latent
+
+def latent_to_tensor(input_im, vae):
+    with torch.no_grad():
+        images = vae.decoder(input_im)
+    return images
+
+def upscale_resolution(image):
+    sr = cv2.dnn_superres.DnnSuperResImpl_create()
+    path = os.path.join(settings.BASE_DIR, 'depthAPI', 'models', 'FSRCNN', 'FSRCNN_x2.pb')
+    sr.readModel(path)
+    sr.setModel("fsrcnn", 2)
+    result = sr.upsample(image)
+    resized = cv2.resize(image, dsize=None, fx=2, fy=2)
+    img = Image.fromarray(resized.astype('uint8'))
+    return img
+
+def extract_features(image, torch_device, swin_transformer):
+    image.to(torch_device)
+    with torch.no_grad():
+        swin_output = swin_transformer(**image)
+    del image
+    image_fea = swin_output.last_hidden_state.squeeze(0)
+    return image_fea
+
+def rescale(image):
+    max_val = torch.max(image)
+    min_val = torch.min(image)
+
+    image = (((image - min_val) / (max_val - min_val)) * 2) - 1
+    return image
+
+def normalize(x):
+    return 2 * (x - x.min()) / (x.max() - x.min()) - 1
+
+def upscale_tensor(image):
+    output = F.interpolate(image.unsqueeze(0), size=(512, 512), mode='bilinear', align_corners=False)
+    return output.squeeze(0)
+
 
 class UAHiRISEDataset(Dataset):
     def __init__(self, root, stage, transform=None):
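The added helpers lean on two names defined elsewhere in app.py: `Autoencoder` (the single-channel VAE restored from MonoChannelVAE.pth) and the `DDIMScheduler` class that appears in the next hunk's context. Both `rescale` and `normalize` implement the same contract, an affine map of a tensor's value range onto [-1, 1]. A minimal self-contained sketch of that contract, using a made-up input tensor:

import torch

def normalize(x):
    # Affine map of [x.min(), x.max()] onto [-1, 1], same formula as the diff above.
    return 2 * (x - x.min()) / (x.max() - x.min()) - 1

x = 5.0 + 37.0 * torch.rand(1, 1, 4, 4)   # hypothetical input with an arbitrary value range
y = normalize(x)
assert torch.isclose(y.min(), torch.tensor(-1.0))
assert torch.isclose(y.max(), torch.tensor(1.0))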
@@ -457,13 +546,49 @@ class DDIMScheduler():
     def __len__(self):
         return self.config.num_train_timesteps
 
+
+image_processor, swin_transformer, vae, unet, scheduler = load_models()
+
+def MonoGeoDepthModelRun(image):
+    batch_size = 1
+    torch_device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+
+
+    image = image.convert("RGB")
+    extracted_image = image_processor(image, return_tensors="pt")
+    image_embeddings = extract_features(extracted_image, torch_device, swin_transformer)
+    image_embeddings = image_embeddings.unsqueeze(0)
+
+    torch.manual_seed(0)
+    random_noise = normalize(torch.randn(1, 1, 512, 512).to(torch_device))
+
+    image_embeddings = image_embeddings.to(torch_device)
+
+    with torch.no_grad():
+        noisy_latents = tensor_to_latent(random_noise, vae)
+        del random_noise
+        t = torch.tensor(1000)
+        model_input = scheduler.scale_model_input(noisy_latents, t)
+        noise_pred = unet(model_input, t, encoder_hidden_states=image_embeddings, return_dict=False)
+        noisy_latents = model_input - noise_pred[0]
+        predicted_dtm = latent_to_tensor(noisy_latents, vae)
+        predicted_dtm = predicted_dtm.detach().cpu()
+
+    image_ = predicted_dtm.squeeze(0)
+    image_ = (image_ - image_.min()) / (image_.max() - image_.min())
+
+    to_pil = ToPILImage()
+    predicted_dtm = to_pil(image_)
+
+    return predicted_dtm
 
-def
+def model(img):
     img_array = np.array(img)
     return img_array
 
 iface = gr.Interface(
-    fn=
+    fn=model,
     inputs="image",
     outputs="image"
 )
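As committed, the interface still passes `fn=model`, a placeholder that converts the uploaded image to a NumPy array and returns it unchanged, so `MonoGeoDepthModelRun` is defined but never invoked by the UI. A sketch of how the inference path might be wired up instead, assuming the definitions above and PIL-typed Gradio components (hypothetical wiring, not part of this commit):

import gradio as gr

# Hypothetical hookup: hand the uploaded PIL image to the diffusion
# inference function; it already calls image.convert("RGB") internally
# and returns a PIL image, so PIL-typed components fit on both ends.
iface = gr.Interface(
    fn=MonoGeoDepthModelRun,       # defined in the second hunk above
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
)
iface.launch()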