# @Time : 2023-02-13 22:56
# @Author : Wang Zhen
# @Email : frozenzhencola@163.com
# @File : SatelliteTool.py
# @Project : TGRS_seqmatch_2023_1
import os
import random
from pathlib import Path

import cv2
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

from utils.geo import BoundaryBox, Projection
from osm.tiling import TileManager, MapTileManager


class UavMapPair(Dataset):
    """Dataset of (UAV image, map raster) pairs for a single city.

    Files are read from ``root/city``:

    * ``uav/``     - UAV images, read with ``cv2.imread``.
    * ``map/``     - map arrays, read with ``np.load``.
    * ``map_vis/`` - path is stored on the instance but not read here.
    * ``info.csv`` - one header row followed by comma-separated rows with the
      columns: id, uav_name, map_name, uav_long, uav_lat, map_long, map_lat,
      tile_size_meters, pixel_per_meter, u, v, yaw, dis.

    Args:
        root: Dataset root directory (``str`` or ``Path``).
        city: Name of the city sub-directory under ``root``.
        training: If true, a random square centre crop is applied to every
            UAV image before colour conversion and ``transform``.
        transform: Optional callable applied to the RGB UAV image (a NumPy
            ``H x W x 3`` array) before it is returned.
    """

    def __init__(
        self,
        root: Path,
        city: str,
        training: bool,
        transform=None,
    ):
        super().__init__()
        # Accept plain strings as well as Path objects.
        root = Path(root)
        self.uav_image_path = root / city / 'uav'
        self.map_path = root / city / 'map'
        self.map_vis = root / city / 'map_vis'
        info_path = root / city / 'info.csv'
        # ndmin=2 keeps the table 2-D even when the CSV has a single data
        # row; otherwise len() and row indexing would operate on columns.
        self.info = np.loadtxt(
            str(info_path), dtype=str, delimiter=",", skiprows=1, ndmin=2
        )
        self.transform = transform
        self.training = training

    def random_center_crop(self, image):
        """Return a square crop of random size centred in ``image``.

        The side length is drawn uniformly from
        ``[max(1, short_side // 2), short_side]``, where ``short_side`` is the
        smaller of the image height and width.
        """
        height, width = image.shape[:2]
        short_side = min(height, width)
        # Randomly choose the crop size; the lower bound is clamped to 1 so a
        # 1-pixel-wide image can never yield an empty crop.
        crop_size = random.randint(max(1, short_side // 2), short_side)
        # Top-left corner of a crop centred in the image.
        start_x = (width - crop_size) // 2
        start_y = (height - crop_size) // 2
        return image[start_y:start_y + crop_size, start_x:start_x + crop_size]

    def __getitem__(self, index: int):
        """Load sample ``index``.

        Returns:
            dict with keys:

            * ``'map'``: the ``np.load``-ed map array as an int64 tensor
              (presumably a class-label raster - confirm with the data
              pipeline).
            * ``'image'``: the RGB UAV image; the output of ``transform`` if
              one is set, otherwise an ``H x W x 3`` uint8 tensor.
            * ``'roll_pitch_yaw'``: float tensor ``(0, 0, yaw)``. Roll and
              pitch are fixed to zero; yaw units come from ``info.csv``
              (TODO confirm degrees vs radians).
            * ``'pixels_per_meter'``: scalar float tensor.
            * ``'uv'``: float tensor ``[u, v]`` from ``info.csv`` (presumably
              the UAV's pixel position in the map tile - verify).

        Raises:
            FileNotFoundError: if the UAV image cannot be read.
        """
        (_sample_id, uav_name, map_name,
         _uav_long, _uav_lat,
         _map_long, _map_lat,
         _tile_size_meters, pixel_per_meter,
         u, v, yaw, _dis) = self.info[index]

        image_path = self.uav_image_path / uav_name
        uav_image = cv2.imread(str(image_path))
        if uav_image is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read UAV image: {image_path}")
        if self.training:
            uav_image = self.random_center_crop(uav_image)
        # OpenCV loads images as BGR; convert to RGB.
        uav_image = cv2.cvtColor(uav_image, cv2.COLOR_BGR2RGB)
        if self.transform:
            uav_image = self.transform(uav_image)

        map_raster = np.load(str(self.map_path / map_name))
        return {
            'map': torch.from_numpy(np.ascontiguousarray(map_raster)).long(),
            # as_tensor accepts both a NumPy array (no transform) and a tensor
            # produced by the transform, without a redundant copy or warning.
            'image': torch.as_tensor(uav_image),
            'roll_pitch_yaw': torch.tensor((0, 0, float(yaw))).float(),
            'pixels_per_meter': torch.tensor(float(pixel_per_meter)).float(),
            "uv": torch.tensor([float(u), float(v)]).float(),
        }

    def __len__(self):
        """Return the number of data rows in ``info.csv``."""
        return len(self.info)


if __name__ == '__main__':
    # Smoke test: load and collate every sample of one city once.
    root = Path('/root/DATASET/OrienterNet/UavMap/')
    city = 'NewYork'
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(256),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])
    # `training` is required. Training mode applies the square centre crop,
    # so after Resize(256) every image is 256x256 and batches collate cleanly.
    dataset = UavMapPair(
        root=root,
        city=city,
        training=True,
        transform=transform,
    )
    loader = DataLoader(dataset, batch_size=3)
    for batch_idx, batch in enumerate(loader):
        pass