supundhananjaya commited on
Commit
bcd9f19
·
verified ·
1 Parent(s): 0e2a983

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -0
app.py CHANGED
@@ -10,7 +10,57 @@ from typing import Dict
10
  import functools
11
  import inspect
12
  from types import SimpleNamespace
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  class Autoencoder(nn.Module):
15
  def __init__(self):
16
  super().__init__()
 
10
  import functools
11
  import inspect
12
  from types import SimpleNamespace
13
+ import torch
14
+ from torch.utils.data import Dataset
15
+ from torchvision import transforms
16
+ import rasterio
17
+ from pathlib import Path
18
+ from torchvision.transforms import ToPILImage
19
+ import numpy as np
20
+
21
+ class UAHiRISEDataset(Dataset):
22
+ def __init__(self, root, stage, transform=None):
23
+ self.root = Path(root)
24
+ self.stage = stage
25
+ self.transform = transform
26
+ self.filenames = self._read_split()
27
+
28
+ def __len__(self):
29
+ return len(self.filenames)
30
+
31
+ def __getitem__(self, idx):
32
+ filename = self.filenames[idx]
33
+ raster_path = self.root / filename
34
+
35
+ raster = rasterio.open(raster_path)
36
+
37
+ left = raster.read(1).astype('uint8')
38
+ dtm = raster.read(2)
39
 
40
+ # converting absolute heigths to relative depths
41
+ dtm = abs(dtm - dtm.min())
42
+
43
+ to_pil = ToPILImage()
44
+
45
+ to_transform = {"image": to_pil(left).convert('RGB'), "dtm": dtm}
46
+
47
+ return self.transform(to_transform)
48
+ # return to_transform
49
+
50
+ def _add_channels(self, image):
51
+ img_expanded = np.stack([image, image, image], axis=-1)
52
+ img_tensor = torch.from_numpy(img_expanded).permute(2, 0, 1)
53
+ return img_tensor
54
+
55
+ def set_transform(self, transform):
56
+ self.transform = transform
57
+
58
+ def _read_split(self):
59
+ split_filename = f'uahirise_{self.stage}.txt'
60
+ split_filepath = Path(f'filenames/{split_filename}')
61
+ filenames = split_filepath.read_text().splitlines()
62
+ return filenames
63
+
64
  class Autoencoder(nn.Module):
65
  def __init__(self):
66
  super().__init__()