import pickle

import matplotlib.pyplot as plt
import torch

position_range = [-32.0, -32.0, -10.0, 32.0, 32.0, 10.0]
depth_num = 64
depth_start = 1


def position_embedding(features, img_features):
    eps = 1e-5
    img_features = img_features.unsqueeze(1)
    B, N, C, tar_H, tar_W = img_features.shape
    device = img_features.device

    # The stitched input comes from three 1920x1080 cameras: every view is
    # cropped 28 px at the top and bottom, and the side views additionally lose
    # 416 px on each side, so each feature-map cell covers a 64x64 px patch.
    crop_top = 28
    crop_left = 416
    H = [16 for _ in range(3)]
    W = [
        64 * 1088 // (1088 * 2 + 1920),  # left view  -> 17 columns
        64 * 1920 // (1088 * 2 + 1920),  # front view -> 30 columns
        64 * 1088 // (1088 * 2 + 1920),  # right view -> 17 columns
    ]

    # Map feature-map indices back to pixel coordinates in the original image:
    # scale by the cropped extent, then add the crop offset. (Previously the
    # offset was divided by the feature size, which shifted the grid.)
    # Left view (16, 17)
    coords_h_l = torch.arange(H[0], device=device).float() * (1080 - 2 * crop_top) / H[0] + crop_top
    coords_w_l = torch.arange(W[0], device=device).float() * (1920 - 2 * crop_left) / W[0] + crop_left
    # Front view (16, 30)
    coords_h_f = torch.arange(H[1], device=device).float() * (1080 - 2 * crop_top) / H[1] + crop_top
    coords_w_f = torch.arange(W[1], device=device).float() * 1920 / W[1]
    # Right view (16, 17)
    coords_h_r = torch.arange(H[2], device=device).float() * (1080 - 2 * crop_top) / H[2] + crop_top
    coords_w_r = torch.arange(W[2], device=device).float() * (1920 - 2 * crop_left) / W[2] + crop_left

    # Linear-increasing discretization (LID) of depth, as in PETR: bin widths
    # grow linearly, so near depths are sampled more densely than far ones.
    index = torch.arange(start=0, end=depth_num, step=1, device=device).float()
    index_1 = index + 1
    bin_size = (position_range[3] - depth_start) / (depth_num * (1 + depth_num))
    coords_d = depth_start + bin_size * index * index_1
    D = coords_d.shape[0]

    coords = [None] * 3  # 0, 1, 2 -> left, front, right
    coords[0] = torch.stack(torch.meshgrid([coords_w_l, coords_h_l, coords_d], indexing='ij')).permute(1, 2, 3, 0)  # (W, H, D, 3)
    coords[1] = torch.stack(torch.meshgrid([coords_w_f, coords_h_f, coords_d], indexing='ij')).permute(1, 2, 3, 0)  # (W, H, D, 3)
    coords[2] = torch.stack(torch.meshgrid([coords_w_r, coords_h_r, coords_d], indexing='ij')).permute(1, 2, 3, 0)  # (W, H, D, 3)

    # Store each frustum point as (u*d, v*d, d), clamping the depth away from
    # zero for numerical stability.
    for c in coords:
        c[..., :2] = c[..., :2] * torch.max(c[..., 2:3], torch.ones_like(c[..., 2:3]) * eps)

    pos_3d_embed = None
    for i in range(3):
        sensor2lidar_rotation = features["sensor2lidar_rotation"][i]
        sensor2lidar_translation = features["sensor2lidar_translation"][i]
        intrinsics = features["intrinsics"][i]
        combine = torch.matmul(sensor2lidar_rotation, torch.inverse(intrinsics)).float()  # (3, 3)
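        # Note on `combine`: each frustum point is stored as (u*d, v*d, d), so
        # K^-1 @ (u*d, v*d, d)^T recovers the point in the camera frame and
        # R @ (.) + t moves it into the lidar frame. combine = R @ K^-1 folds
        # both linear maps into one 3x3 matmul; t is added separately below.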
        # (W, H, D, 3) -> (B, N, W, H, D, 3, 1); the view assumes N == 1, i.e.
        # one camera per loop iteration.
        coords3d = coords[i].view(1, N, W[i], H[i], D, 3, 1).repeat(B, 1, 1, 1, 1, 1, 1)
        # (3, 3) -> (B, N, W, H, D, 3, 3)
        combine = combine.view(B, N, 1, 1, 1, 3, 3).repeat(1, 1, W[i], H[i], D, 1, 1)
        coords3d = torch.matmul(combine, coords3d).squeeze(-1)  # (B, N, W, H, D, 3)
        # Cast the translation to float32 to match coords3d before the in-place add.
        sensor2lidar_translation = sensor2lidar_translation.view(B, N, 1, 1, 1, 3).float()
        coords3d += sensor2lidar_translation

        # Normalize each axis into [0, 1] over the configured position range.
        # (PETR additionally builds an out-of-range mask here; omitted in this script.)
        coords3d[..., 0:1] = (coords3d[..., 0:1] - position_range[0]) / (position_range[3] - position_range[0])
        coords3d[..., 1:2] = (coords3d[..., 1:2] - position_range[1]) / (position_range[4] - position_range[1])
        coords3d[..., 2:3] = (coords3d[..., 2:3] - position_range[2]) / (position_range[5] - position_range[2])

        # Concatenate the three views along the width axis:
        # (B, N, 17, 16, 64, 3) ++ (B, N, 30, 16, 64, 3) ++ (B, N, 17, 16, 64, 3)
        # -> (B, N, 64, 16, 64, 3)
        if pos_3d_embed is None:
            pos_3d_embed = coords3d
        else:
            pos_3d_embed = torch.cat((pos_3d_embed, coords3d), dim=2)

    # (B, N, W, H, D, 3) -> (B*N, D*3, H, W)
    pos_3d_embed = pos_3d_embed.permute(0, 1, 4, 5, 3, 2).contiguous().view(B * N, -1, tar_H, tar_W)
    return pos_3d_embed


if __name__ == '__main__':
    H, W = 16, 64
    with open('/mnt/g/navsim/navsim_logs/tiny/2021.05.12.22.28.35_veh-35_00620_01164.pkl', 'rb') as f:
        logs = pickle.load(f)
    log = logs[0]
    features = {
        'sensor2lidar_rotation': [
            torch.from_numpy(log['cams']['CAM_L0']['sensor2lidar_rotation']),
            torch.from_numpy(log['cams']['CAM_F0']['sensor2lidar_rotation']),
            torch.from_numpy(log['cams']['CAM_R0']['sensor2lidar_rotation']),
        ],
        'sensor2lidar_translation': [
            torch.from_numpy(log['cams']['CAM_L0']['sensor2lidar_translation']),
            torch.from_numpy(log['cams']['CAM_F0']['sensor2lidar_translation']),
            torch.from_numpy(log['cams']['CAM_R0']['sensor2lidar_translation']),
        ],
        'intrinsics': [
            torch.from_numpy(log['cams']['CAM_L0']['cam_intrinsic']),
            torch.from_numpy(log['cams']['CAM_F0']['cam_intrinsic']),
            torch.from_numpy(log['cams']['CAM_R0']['cam_intrinsic']),
        ],
    }
    img_features = torch.randn((1, 3, H, W))
    coords_3d = position_embedding(features, img_features)

    # Scatter the normalized frustum points, one depth column per feature cell.
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    # (B*N, D*3, H, W) -> (H, W, D, 3); hoisted out of the loop since it does
    # not depend on i or j.
    frustum_points = coords_3d.permute(2, 3, 1, 0).reshape(H, W, depth_num, 3)
    for i in range(H):
        for j in range(W):
            pixel_points = frustum_points[i, j]
            ax.scatter(pixel_points[:, 0], pixel_points[:, 1], pixel_points[:, 2])
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.view_init(elev=90, azim=0)
    plt.show()
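
    # Sanity check (a minimal sketch, assuming B = N = 1 as constructed above):
    # the embedding holds one normalized (x, y, z) per depth bin per feature cell.
    assert coords_3d.shape == (1, depth_num * 3, H, W)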