# MapLocNet / dataset / UAV / dataset.py
# Author: wangerniu
# Commit message.
# Commit: 124ba77
import torch
from torch.utils.data import Dataset
import os
import cv2
# @Time : 2023-02-13 22:56
# @Author : Wang Zhen
# @Email : [email protected]
# @File : SatelliteTool.py
# @Project : TGRS_seqmatch_2023_1
import numpy as np
import random
from utils.geo import BoundaryBox, Projection
from osm.tiling import TileManager,MapTileManager
from pathlib import Path
from torchvision import transforms
from torch.utils.data import DataLoader
class UavMapPair(Dataset):
    """Dataset of (UAV image, map raster) pairs for a single city.

    Layout read under ``root / city``:

    * ``uav/``      UAV images, read with ``cv2.imread``.
    * ``map/``      map rasters, read with ``np.load``.
    * ``map_vis/``  path stored as ``self.map_vis``; not read by this class.
    * ``info.csv``  one header row, then one comma-separated row per sample
      with 13 columns: id, uav_name, map_name, uav_long, uav_lat, map_long,
      map_lat, tile_size_meters, pixel_per_meter, u, v, yaw, dis.
    """

    def __init__(
        self,
        root: Path,
        city: str,
        training: bool,
        transform,
    ):
        """
        Args:
            root: dataset root directory (a plain ``str`` is also accepted).
            city: name of the city sub-directory under ``root``.
            training: when True, ``__getitem__`` applies a random center
                crop to the UAV image.
            transform: optional callable applied to the RGB UAV image
                (numpy array from OpenCV); a falsy value disables it.
        """
        super().__init__()
        root = Path(root)  # accept str as well as Path for the joins below
        self.uav_image_path = root / city / 'uav'
        self.map_path = root / city / 'map'
        self.map_vis = root / city / 'map_vis'
        info_path = root / city / 'info.csv'
        # ndmin=2 keeps the table 2-D even when the CSV holds a single sample;
        # without it np.loadtxt returns a 1-D row and len()/indexing break.
        self.info = np.loadtxt(
            str(info_path), dtype=str, delimiter=",", skiprows=1, ndmin=2
        )
        self.transform = transform
        self.training = training

    def random_center_crop(self, image):
        """Return a centered square crop with a random side length.

        The side length is drawn uniformly from ``[max(1, m // 2), m]`` with
        ``m = min(height, width)``, so the crop is never empty for a
        non-empty image.
        """
        height, width = image.shape[:2]
        min_side = min(height, width)
        if min_side == 0:
            return image  # nothing to crop
        # Lower bound of 1 so 1-pixel-wide images never yield an empty crop.
        crop_size = random.randint(max(1, min_side // 2), min_side)
        # Top-left corner that centers the crop.
        start_x = (width - crop_size) // 2
        start_y = (height - crop_size) // 2
        return image[start_y:start_y + crop_size, start_x:start_x + crop_size]

    def __getitem__(self, index: int):
        """Load sample ``index``.

        Returns:
            dict with keys
            ``'map'``              -- map raster from ``map/<map_name>`` as a
                                      ``torch.long`` tensor,
            ``'image'``            -- RGB UAV image (cropped when training,
                                      transformed when a transform is set),
            ``'roll_pitch_yaw'``   -- float tensor ``(0, 0, yaw)``,
            ``'pixels_per_meter'`` -- scalar float tensor,
            ``'uv'``               -- float tensor ``(u, v)``.

        Raises:
            FileNotFoundError: if the UAV image cannot be read.
        """
        (
            _sample_id, uav_name, map_name,
            _uav_long, _uav_lat,
            _map_long, _map_lat,
            _tile_size_meters, pixel_per_meter,
            u, v, yaw, _dis,
        ) = self.info[index]

        uav_path = self.uav_image_path / uav_name
        uav_image = cv2.imread(str(uav_path))
        # cv2.imread signals failure by returning None rather than raising.
        if uav_image is None:
            raise FileNotFoundError(f"Could not read UAV image: {uav_path}")
        if self.training:
            uav_image = self.random_center_crop(uav_image)
        # OpenCV loads BGR; convert to RGB before any transform.
        uav_image = cv2.cvtColor(uav_image, cv2.COLOR_BGR2RGB)
        if self.transform:
            uav_image = self.transform(uav_image)

        map_raster = np.load(str(self.map_path / map_name))
        return {
            'map': torch.from_numpy(np.ascontiguousarray(map_raster)).long(),
            # as_tensor avoids the copy-construct warning torch.tensor emits
            # when the transform already returned a tensor.
            'image': torch.as_tensor(uav_image),
            'roll_pitch_yaw': torch.tensor((0, 0, float(yaw))).float(),
            'pixels_per_meter': torch.tensor(float(pixel_per_meter)).float(),
            "uv": torch.tensor([float(u), float(v)]).float(),
        }

    def __len__(self):
        """Number of samples (data rows in ``info.csv``)."""
        return len(self.info)
if __name__ == '__main__':
    # Smoke test: build the NewYork split and iterate over it once.
    root = Path('/root/DATASET/OrienterNet/UavMap/')
    city = 'NewYork'
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(256),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])
    dataset = UavMapPair(
        root=root,
        city=city,
        training=False,  # required argument; omitting it raised TypeError
        transform=transform,
    )
    dataloader = DataLoader(dataset, batch_size=3)
    for batch_idx, batch in enumerate(dataloader):
        pass
        # Optional visualization (needs matplotlib plus the project's
        # Colormap / plot_images helpers):
        # Convert the first image tensor to a NumPy array and display it.
        # numpy_array = batch['image'][0].numpy()
        # plt.imshow(numpy_array.transpose(1, 2, 0))
        # plt.axis('off')
        # plt.show()
        #
        # map_viz, label = Colormap.apply(batch['map'][0])
        # map_viz = (map_viz * 255).astype(np.uint8)
        # plot_images([map_viz], titles=["OpenStreetMap raster"])