|
import torch |
|
from torch.utils.data import Dataset |
|
import os |
|
import cv2 |
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
import random |
|
from utils.geo import BoundaryBox, Projection |
|
from osm.tiling import TileManager,MapTileManager |
|
from pathlib import Path |
|
from torchvision import transforms |
|
from torch.utils.data import DataLoader |
|
|
|
class UavMapPair(Dataset):
    """Dataset of (UAV image, map array) pairs for a single city.

    Layout read under ``root / city``:

    * ``uav/``     -- UAV images, loaded with ``cv2.imread``.
    * ``map/``     -- map arrays, loaded with ``np.load``.
    * ``map_vis/`` -- stored as ``self.map_vis``; not read by this class.
    * ``info.csv`` -- one header row (skipped), then one comma-separated row
      per sample with 13 columns: id, uav_name, map_name, uav_long, uav_lat,
      map_long, map_lat, tile_size_meters, pixel_per_meter, u, v, yaw, dis.
    """

    def __init__(
        self,
        root: Path,
        city: str,
        training: bool = False,
        transform=None,
    ):
        """
        Args:
            root: Dataset root directory (``Path`` or path-like).
            city: Sub-directory of ``root`` holding this city's data.
            training: If true, ``__getitem__`` applies a random square center
                crop to the UAV image.
            transform: Optional callable applied to the RGB UAV image (the
                numpy array returned by ``cv2.cvtColor``) before returning.
        """
        super().__init__()

        city_dir = Path(root) / city
        self.uav_image_path = city_dir / 'uav'
        self.map_path = city_dir / 'map'
        self.map_vis = city_dir / 'map_vis'
        info_path = city_dir / 'info.csv'

        # ndmin=2 keeps the table 2-D even when the CSV holds a single
        # sample; otherwise loadtxt returns one 1-D row, and both indexing
        # and len() would operate on columns instead of samples.
        self.info = np.loadtxt(
            str(info_path), dtype=str, delimiter=",", skiprows=1, ndmin=2
        )

        self.transform = transform
        self.training = training

    def random_center_crop(self, image):
        """Return a centered square crop of ``image`` with a random side.

        The side length is drawn uniformly from
        ``[min(H, W) // 2, min(H, W)]``, with a lower bound of 1 pixel so a
        non-empty image never yields an empty crop. The result is a slice
        (view) of the input.
        """
        height, width = image.shape[:2]
        short_side = min(height, width)

        # Without max(1, ...) a short side of 1 allows crop_size == 0.
        low = max(1, short_side // 2) if short_side else 0
        crop_size = random.randint(low, short_side)

        start_x = (width - crop_size) // 2
        start_y = (height - crop_size) // 2
        return image[start_y:start_y + crop_size, start_x:start_x + crop_size]

    def __getitem__(self, index: int):
        """Load sample ``index``.

        Returns:
            dict with keys:
            ``'map'`` -- the ``map/`` array as an int64 tensor;
            ``'image'`` -- the RGB UAV image: ``transform``'s output if a
            transform is set (wrapped with ``torch.tensor`` unless it is
            already a tensor), otherwise the image array as a tensor;
            ``'roll_pitch_yaw'`` -- float tensor ``(0, 0, yaw)``;
            ``'pixels_per_meter'`` -- scalar float tensor;
            ``'uv'`` -- float tensor ``(u, v)``.

        Raises:
            FileNotFoundError: if the UAV image cannot be read.
        """
        (_sample_id, uav_name, map_name,
         _uav_long, _uav_lat,
         _map_long, _map_lat,
         _tile_size_meters, pixel_per_meter,
         u, v, yaw, _dis) = self.info[index]

        image_file = self.uav_image_path / uav_name
        uav_image = cv2.imread(str(image_file))
        # cv2.imread signals failure by returning None instead of raising.
        if uav_image is None:
            raise FileNotFoundError(f"Could not read UAV image: {image_file}")

        if self.training:
            uav_image = self.random_center_crop(uav_image)
        uav_image = cv2.cvtColor(uav_image, cv2.COLOR_BGR2RGB)
        if self.transform:
            uav_image = self.transform(uav_image)
        # Transforms such as ToTensor already return a tensor; passing a
        # tensor to torch.tensor() copies it again and emits a UserWarning.
        if not isinstance(uav_image, torch.Tensor):
            uav_image = torch.tensor(uav_image)

        map_array = np.load(str(self.map_path / map_name))

        return {
            'map': torch.from_numpy(np.ascontiguousarray(map_array)).long(),
            'image': uav_image,
            'roll_pitch_yaw': torch.tensor((0, 0, float(yaw))).float(),
            'pixels_per_meter': torch.tensor(float(pixel_per_meter)).float(),
            'uv': torch.tensor([float(u), float(v)]).float(),
        }

    def __len__(self):
        """Return the number of samples (data rows in ``info.csv``)."""
        return len(self.info)
|
if __name__ == '__main__':
    # Smoke test: build the dataset for one city and iterate over every
    # batch to check that all samples load and collate. The root path is
    # machine-specific.
    root = Path('/root/DATASET/OrienterNet/UavMap/')
    city = 'NewYork'

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(256),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])

    dataset = UavMapPair(
        root=root,
        city=city,
        # Passed explicitly; True also exercises the random-crop path.
        training=True,
        transform=transform,
    )

    dataloader = DataLoader(dataset, batch_size=3)
    num_batches = sum(1 for _batch in dataloader)
    print(f"Iterated {num_batches} batches from {len(dataset)} samples")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|