supundhananjaya committed on
Commit
710f982
·
verified ·
1 Parent(s): bcd9f19

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +129 -4
app.py CHANGED
@@ -10,13 +10,102 @@ from typing import Dict
10
  import functools
11
  import inspect
12
  from types import SimpleNamespace
13
- import torch
14
  from torch.utils.data import Dataset
15
  from torchvision import transforms
16
  import rasterio
17
  from pathlib import Path
18
  from torchvision.transforms import ToPILImage
19
- import numpy as np
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  class UAHiRISEDataset(Dataset):
22
  def __init__(self, root, stage, transform=None):
@@ -457,13 +546,49 @@ class DDIMScheduler():
457
  def __len__(self):
458
  return self.config.num_train_timesteps
459
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
460
 
461
- def dummy_model(img):
462
  img_array = np.array(img)
463
  return img_array
464
 
465
  iface = gr.Interface(
466
- fn=dummy_model,
467
  inputs="image",
468
  outputs="image"
469
  )
 
10
  import functools
11
  import inspect
12
  from types import SimpleNamespace
 
13
  from torch.utils.data import Dataset
14
  from torchvision import transforms
15
  import rasterio
16
  from pathlib import Path
17
  from torchvision.transforms import ToPILImage
18
+ from base64 import b64encode
19
+ import gc
20
+ from datasets import load_dataset
21
+ import torchvision
22
+ import torch.nn.functional as F
23
+ from IPython.display import HTML
24
+ from matplotlib import pyplot as plt
25
+ from pathlib import Path
26
+ from torch import autocast
27
+ from torchvision import transforms as tfms
28
+ from tqdm.auto import tqdm
29
+ from transformers import CLIPTextModel, CLIPTokenizer, logging
30
+ import os
31
+ import csv
32
+ from torchvision.utils import save_image
33
+ import torch
34
+ import cv2
35
+ from PIL import Image
36
+ import os
37
+ from django.conf import settings
38
+ import torch.nn.functional as F
39
+ import os
40
+ import torch
41
+ from transformers import AutoImageProcessor, SwinModel
42
+ from diffusers import UNet2DConditionModel
43
+
44
def load_models():
    """Instantiate the image processor, Swin encoder, VAE, UNet and scheduler.

    Model files are resolved under ``<settings.BASE_DIR>/depthAPI/models``.
    The VAE, Swin transformer and UNet are moved to CUDA when
    ``torch.cuda.is_available()`` is true; otherwise they stay on the CPU.

    Returns:
        The 5-tuple ``(image_processor, swin_transformer, vae, unet, scheduler)``.
    """
    if torch.cuda.is_available():
        torch_device = torch.device('cuda')
    else:
        torch_device = torch.device('cpu')

    # Every artefact lives below the same directory, so build that prefix once.
    models_root = os.path.join(settings.BASE_DIR, 'depthAPI', 'models')
    processor_dir = os.path.join(models_root, 'image_processor')
    swin_dir = os.path.join(models_root, 'swin_transformer')
    vae_weights_file = os.path.join(models_root, 'vae', 'MonoChannelVAE.pth')
    unet_dir = os.path.join(models_root, 'unet')

    image_processor = AutoImageProcessor.from_pretrained(processor_dir)
    swin_transformer = SwinModel.from_pretrained(swin_dir)

    # The VAE weights are deserialised onto the CPU first and moved afterwards.
    vae = Autoencoder()
    vae_state = torch.load(vae_weights_file, map_location=torch.device('cpu'))
    vae.load_state_dict(vae_state)

    unet = UNet2DConditionModel.from_pretrained(unet_dir)
    scheduler = DDIMScheduler(
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule='linear',
        num_train_timesteps=1000,
    )

    vae = vae.to(torch_device)
    swin_transformer = swin_transformer.to(torch_device)
    unet = unet.to(torch_device)

    return image_processor, swin_transformer, vae, unet, scheduler
66
+
67
def tensor_to_latent(input_im, vae):
    """Encode ``input_im`` with ``vae.encoder`` while gradients are disabled.

    Args:
        input_im: Value passed unchanged to ``vae.encoder``.
        vae: Object exposing a callable ``encoder`` attribute.

    Returns:
        Whatever ``vae.encoder(input_im)`` returns.
    """
    encode = vae.encoder
    with torch.no_grad():
        return encode(input_im)
71
+
72
def latent_to_tensor(input_im, vae):
    """Decode ``input_im`` with ``vae.decoder`` while gradients are disabled.

    Args:
        input_im: Value passed unchanged to ``vae.decoder``.
        vae: Object exposing a callable ``decoder`` attribute.

    Returns:
        Whatever ``vae.decoder(input_im)`` returns.
    """
    decode = vae.decoder
    with torch.no_grad():
        return decode(input_im)
76
+
77
def upscale_resolution(image):
    """Upscale ``image`` by a factor of two with the FSRCNN network.

    The model file is read from
    ``<settings.BASE_DIR>/depthAPI/models/FSRCNN/FSRCNN_x2.pb`` on every call.

    Args:
        image: Array accepted by ``DnnSuperResImpl.upsample``.
            NOTE(review): presumably a uint8 image array as produced by
            OpenCV; confirm against the callers.

    Returns:
        A ``PIL.Image.Image`` built from the super-resolved array cast to
        ``uint8``.
    """
    sr = cv2.dnn_superres.DnnSuperResImpl_create()
    path = os.path.join(settings.BASE_DIR, 'depthAPI', 'models', 'FSRCNN', 'FSRCNN_x2.pb')
    sr.readModel(path)
    sr.setModel("fsrcnn", 2)

    # The network output used to be computed and then thrown away in favour of
    # a plain cv2.resize of the input, so the FSRCNN model had no effect on the
    # returned image. Return the super-resolved result instead.
    upscaled = sr.upsample(image)
    return Image.fromarray(upscaled.astype('uint8'))
86
+
87
def extract_features(image, torch_device, swin_transformer):
    """Run the Swin transformer on pre-processed inputs and return its features.

    Args:
        image: Mapping of model inputs that also offers ``.to(device)``; it is
            unpacked as keyword arguments for ``swin_transformer``.
        torch_device: Device the inputs are moved to before the forward pass.
        swin_transformer: Callable whose result has a ``last_hidden_state``
            tensor attribute.

    Returns:
        ``last_hidden_state`` with a leading dimension of size 1 squeezed away.
    """
    # Rebind the result: ``.to()`` is not guaranteed to move data in place
    # (tensors, for example, return a new object), so ignoring the return
    # value can leave the inputs on the wrong device.
    image = image.to(torch_device)
    with torch.no_grad():
        swin_output = swin_transformer(**image)
    # Drop the local reference to the device-side inputs once they are used.
    del image
    return swin_output.last_hidden_state.squeeze(0)
94
+
95
def rescale(image):
    """Linearly map ``image`` so its minimum becomes -1 and its maximum +1.

    Args:
        image: Non-empty tensor.

    Returns:
        Tensor of the same shape scaled into ``[-1, 1]``. When every element
        is equal there is no range to stretch; each element then equals the
        minimum and is mapped to -1 instead of producing NaN from ``0 / 0``.
    """
    max_val = torch.max(image)
    min_val = torch.min(image)

    # Replace a zero range by 1 so constant inputs do not divide by zero.
    value_range = max_val - min_val
    value_range = torch.where(value_range == 0, torch.ones_like(value_range), value_range)

    image = (((image - min_val) / value_range) * 2) - 1
    return image
101
+
102
def normalize(x):
    """Linearly map ``x`` so its minimum becomes -1 and its maximum +1.

    Args:
        x: Non-empty tensor.

    Returns:
        Tensor of the same shape scaled into ``[-1, 1]``. When every element
        is equal there is no range to stretch; each element then equals the
        minimum and is mapped to -1 instead of producing NaN from ``0 / 0``.
    """
    lowest = x.min()
    # Replace a zero range by 1 so constant inputs do not divide by zero.
    value_range = x.max() - lowest
    value_range = torch.where(value_range == 0, torch.ones_like(value_range), value_range)
    return 2 * (x - lowest) / value_range - 1
104
+
105
def upscale_tensor(image, size=(512, 512)):
    """Resize an unbatched tensor with bilinear interpolation.

    Args:
        image: 3-D tensor; a leading batch axis is added for the call because
            bilinear ``F.interpolate`` works on 4-D input.
        size: Target ``(height, width)``. Defaults to ``(512, 512)``, the
            value that was previously hard-coded.

    Returns:
        The resized tensor with the temporary batch axis removed again.
    """
    output = F.interpolate(image.unsqueeze(0), size=size, mode='bilinear', align_corners=False)
    return output.squeeze(0)
108
+
109
 
110
  class UAHiRISEDataset(Dataset):
111
  def __init__(self, root, stage, transform=None):
 
546
  def __len__(self):
547
  return self.config.num_train_timesteps
548
 
549
+
550
+ image_processor, swin_transformer, vae, unet, scheduler = load_models()
551
+
552
def MonoGeoDepthModelRun(image):
    """Predict a single-channel image from ``image`` with the loaded models.

    Uses the module-level ``image_processor``, ``swin_transformer``, ``vae``,
    ``unet`` and ``scheduler`` objects.

    Args:
        image: Object with a ``convert("RGB")`` method whose result is
            accepted by ``image_processor`` (e.g. a PIL image).

    Returns:
        The PIL image produced by ``ToPILImage`` from the decoded prediction
        after min-max scaling to ``[0, 1]``.

    Note:
        ``torch.manual_seed(0)`` reseeds the global torch RNG on every call,
        so the starting noise is identical between calls.
    """
    torch_device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Condition the UNet on Swin features of the RGB input.
    image = image.convert("RGB")
    extracted_image = image_processor(image, return_tensors="pt")
    image_embeddings = extract_features(extracted_image, torch_device, swin_transformer)
    image_embeddings = image_embeddings.unsqueeze(0)

    torch.manual_seed(0)
    random_noise = normalize(torch.randn(1, 1, 512, 512).to(torch_device))

    image_embeddings = image_embeddings.to(torch_device)

    with torch.no_grad():
        noisy_latents = tensor_to_latent(random_noise, vae)
        del random_noise
        # NOTE(review): a single UNet pass at timestep 1000; the scheduler is
        # only used for scale_model_input here. Confirm this is intended.
        t = torch.tensor(1000)
        model_input = scheduler.scale_model_input(noisy_latents, t)
        noise_pred = unet(model_input, t, encoder_hidden_states=image_embeddings, return_dict=False)
        noisy_latents = model_input - noise_pred[0]
        predicted_dtm = latent_to_tensor(noisy_latents, vae)
    predicted_dtm = predicted_dtm.detach().cpu()

    # Min-max scale to [0, 1]. A constant prediction has a zero range, which
    # used to give NaN from 0 / 0; substitute 1 so it becomes all zeros.
    image_ = predicted_dtm.squeeze(0)
    lowest = image_.min()
    value_range = image_.max() - lowest
    value_range = torch.where(value_range == 0, torch.ones_like(value_range), value_range)
    image_ = (image_ - lowest) / value_range

    to_pil = ToPILImage()
    predicted_dtm = to_pil(image_)

    return predicted_dtm
585
 
586
def model(img):
    """Gradio handler: return ``img`` converted to a NumPy array.

    NOTE(review): this handler only echoes its input; it does not call
    ``MonoGeoDepthModelRun``. Confirm whether the depth pipeline is meant to
    be wired in here.

    Args:
        img: Any value accepted by ``numpy.array``.

    Returns:
        ``numpy.ndarray`` built from ``img``.
    """
    # Imported locally because the module-level ``import numpy as np`` was
    # removed from the visible import block, which would make ``np`` undefined.
    import numpy as np

    img_array = np.array(img)
    return img_array
589
 
590
  iface = gr.Interface(
591
+ fn=model,
592
  inputs="image",
593
  outputs="image"
594
  )