import numpy as np
import open3d as o3d
import torch
from tqdm import tqdm
import torch.nn.functional as F

def pts2normal(pts):
    """Estimate per-pixel normals from an organized (H, W, 3) point map.

    For every interior pixel the normal is the normalized cross product of
    the central difference along axis 0 (rows) and the central difference
    along axis 1 (columns): ``normalize(d_row x d_col)``. Central differences
    are undefined on the one-pixel border. Each border pixel therefore copies
    the normal of the nearest interior pixel (edge replication, corners
    included).

    Args:
        pts: tensor of shape (H, W, 3) holding one 3-D point per pixel.
            H and W must both be at least 3.

    Returns:
        Tensor of shape (H, W, 3) with unit-length normals. Where the cross
        product is zero, ``F.normalize`` leaves a zero vector.

    Raises:
        ValueError: if ``pts`` is not of shape (H, W, 3), or if H or W is
            smaller than 3 (no interior pixel exists).
    """
    if pts.dim() != 3 or pts.shape[-1] != 3:
        raise ValueError(f"pts must have shape (H, W, 3), got {tuple(pts.shape)}")
    h, w, _ = pts.shape
    if h < 3 or w < 3:
        raise ValueError(f"pts must be at least 3x3 to compute normals, got {h}x{w}")

    # Central differences over the (H-2, W-2) interior pixels.
    dx = pts[2:, 1:-1] - pts[:-2, 1:-1]  # along axis 0 (rows)
    dy = pts[1:-1, 2:] - pts[1:-1, :-2]  # along axis 1 (columns)
    normal_map = F.normalize(torch.cross(dx, dy, dim=-1), dim=-1)

    # Edge-replicate onto the full grid: output pixel (i, j) takes interior
    # normal (clamp(i - 1), clamp(j - 1)). This copies the adjacent interior
    # row/column to each border and the interior corner to each corner.
    rows = (torch.arange(h, device=pts.device) - 1).clamp(0, h - 3)
    cols = (torch.arange(w, device=pts.device) - 1).clamp(0, w - 3)
    return normal_map[rows[:, None], cols[None, :]]

def point2mesh(pcd, depth=8, density_threshold=0.1, clean_mesh=True):
    """Reconstruct a triangle mesh from a point cloud via Poisson reconstruction.

    Args:
        pcd: open3d.geometry.PointCloud to reconstruct. Open3D's Poisson
            reconstruction relies on point normals. This presumably means the
            caller must estimate/orient them beforehand (TODO confirm).
        depth: octree depth passed to Open3D's Poisson solver.
        density_threshold: quantile in [0, 1]. Vertices whose Poisson density
            falls below this quantile of all vertex densities are removed.
        clean_mesh: if True, remove degenerate and duplicated triangles,
            duplicated vertices and non-manifold edges after pruning.

    Returns:
        The reconstructed open3d.geometry.TriangleMesh with triangle normals
        computed.
    """
    print("\nPerforming Poisson surface reconstruction...")
    mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(
        pcd, depth=depth, width=0, scale=1.1, linear_fit=False)

    print(f"Reconstructed mesh has {len(mesh.vertices)} vertices and {len(mesh.triangles)} triangles")

    densities = np.asarray(densities)

    print("\nPruning low-density vertices...")
    if densities.size > 0:
        # Rescale densities to [0, 1]. The quantile cut below does not depend
        # on this rescaling. The span guard avoids 0/0 -> NaN when every
        # vertex has the same density; in that case nothing is pruned either way.
        span = densities.max() - densities.min()
        if span > 0:
            densities = (densities - densities.min()) / span

        # Drop the lowest-density `density_threshold` fraction of vertices.
        vertices_to_remove = densities < np.quantile(densities, density_threshold)
        mesh.remove_vertices_by_mask(vertices_to_remove)

    print(f"Pruned mesh has {len(mesh.vertices)} vertices and {len(mesh.triangles)} triangles")

    if clean_mesh:
        print("\nCleaning the mesh...")
        mesh.remove_degenerate_triangles()
        mesh.remove_duplicated_triangles()
        mesh.remove_duplicated_vertices()
        mesh.remove_non_manifold_edges()

        print(f"Final cleaned mesh has {len(mesh.vertices)} vertices and {len(mesh.triangles)} triangles")

    mesh.compute_triangle_normals()
    return mesh

def combine_and_clean_point_clouds(pcds, voxel_size):
    """
    Combine, downsample, and clean a list of point clouds.

    The clouds are concatenated as-is (no transformation is applied). The
    result is voxel-downsampled, then passed through Open3D's statistical
    outlier removal (20 neighbours, std_ratio 2.0).

    Parameters:
    pcds (list): List of open3d.geometry.PointCloud objects to be processed.
    voxel_size (float): The size of the voxel for downsampling.

    Returns:
    o3d.geometry.PointCloud: The cleaned and combined point cloud.
    """
    print("\nCombining point clouds...")
    pcd_combined = o3d.geometry.PointCloud()
    for p3d in pcds:
        pcd_combined += p3d

    print("\nDownsampling the combined point cloud...")
    # Record the size before reassigning, so the log reports both counts.
    n_before = len(pcd_combined.points)
    pcd_combined = pcd_combined.voxel_down_sample(voxel_size)
    print(f"Downsampled from {n_before} to {len(pcd_combined.points)} points")

    print("\nCleaning the combined point cloud...")
    # Only the kept indices are needed; the filtered cloud is rebuilt from them.
    _, ind = pcd_combined.remove_statistical_outlier(nb_neighbors=20, std_ratio=2.0)
    pcd_cleaned = pcd_combined.select_by_index(ind)

    print(f"Cleaned point cloud contains {len(pcd_cleaned.points)} points.")
    print(f"Removed {len(pcd_combined.points) - len(pcd_cleaned.points)} outlier points.")

    return pcd_cleaned

def improved_multiway_registration(pcds, descriptors=None, voxel_size=0.05,
                                   max_correspondence_distance_coarse=None, max_correspondence_distance_fine=None,
                                   overlap=5, quadratic_overlap=False, use_colored_icp=False):
    """Register a sequence of point clouds with pairwise ICP and pose-graph optimization.

    A pose graph is built with one node per cloud, all initialised at the
    identity pose. Its edges come from ICP between sequential neighbours and,
    optionally, descriptor-based loop closures. The graph is optimized with
    Levenberg-Marquardt. Every cloud touched by more than three edges is then
    transformed by its optimized pose and merged. The merged cloud is
    voxel-downsampled and outlier-filtered.

    Args:
        pcds: list of open3d.geometry.PointCloud, used as-is for ICP (no
            downsampling happens here). The input clouds are not modified;
            transformed copies are merged. Point-to-plane / colored ICP
            presumably need normals (and colors) to be present; TODO confirm.
        descriptors: optional sequence of 1-D torch tensors, one per cloud.
            Pairs at least 3 indices apart whose dot product exceeds 0.9
            become loop-closure candidates. Presumably the descriptors are
            L2-normalized, so the dot product is a cosine similarity; TODO
            confirm with callers.
        voxel_size: base length scale. It sets the default ICP distances and
            the final downsampling voxel (``voxel_size * 0.1``).
        max_correspondence_distance_coarse: defaults to ``1.5 * voxel_size``.
            Only reported; no coarse ICP stage currently uses it.
        max_correspondence_distance_fine: ICP correspondence distance, also
            passed to global optimization. Defaults to ``0.15 * voxel_size``.
        overlap: cloud i is registered against clouds i+1 .. i+overlap.
        quadratic_overlap: if True, also register cloud i against i + 2**k
            for 1 <= k <= overlap whenever 2**k > overlap.
        use_colored_icp: use colored ICP instead of point-to-plane ICP.

    Returns:
        (pcd_cleaned, pose_graph, valid_nodes): the merged and cleaned cloud,
        the optimized PoseGraph, and one bool per input cloud telling whether
        that cloud was merged.
    """
    if max_correspondence_distance_coarse is None:
        max_correspondence_distance_coarse = voxel_size * 1.5
    if max_correspondence_distance_fine is None:
        max_correspondence_distance_fine = voxel_size * 0.15

    def pairwise_registration(source, target, use_colored_icp, max_correspondence_distance_coarse, max_correspondence_distance_fine):
        """Run a single fine ICP from the identity and return (T, information, ok).

        On RuntimeError, or when fitness is below FITNESS_THRESHOLD, this
        returns (None, None, False).
        NOTE(review): ``max_correspondence_distance_coarse`` is accepted but
        unused; no coarse ICP stage runs before the fine one.
        """
        current_transformation = np.identity(4)
        try:
            if use_colored_icp:
                icp_fine = o3d.pipelines.registration.registration_colored_icp(
                        source, target, max_correspondence_distance_fine, current_transformation,
                        o3d.pipelines.registration.TransformationEstimationForColoredICP(),
                        o3d.pipelines.registration.ICPConvergenceCriteria(relative_fitness=1e-6,
                                                                      relative_rmse=1e-6,
                                                                      max_iteration=100))
            else:
                icp_fine = o3d.pipelines.registration.registration_icp(
                    source, target, max_correspondence_distance_fine,
                    current_transformation,
                    o3d.pipelines.registration.TransformationEstimationPointToPlane())

            fitness = icp_fine.fitness
            FITNESS_THRESHOLD = 0.01

            if fitness >= FITNESS_THRESHOLD:
                current_transformation = icp_fine.transformation

                information_icp = o3d.pipelines.registration.get_information_matrix_from_point_clouds(
                    source, target, max_correspondence_distance_fine,
                    current_transformation)
                return current_transformation, information_icp, True
            else:
                print(f"Registration failed. Fitness {fitness} is below threshold {FITNESS_THRESHOLD}")
                return None, None, False

        except RuntimeError as e:
            print(f"  ICP registration failed: {str(e)}")
            return None, None, False

    def detect_loop_closure(descriptors, min_interval=3, similarity_threshold=0.9):
        """Return (i, j) pairs with j >= i + min_interval whose descriptor dot product exceeds the threshold."""
        n_pcds = len(descriptors)
        loop_edges = []

        for i in range(n_pcds):
            for j in range(i + min_interval, n_pcds):
                similarity = torch.dot(descriptors[i], descriptors[j])
                if similarity > similarity_threshold:
                    loop_edges.append((i, j))

        return loop_edges

    def generate_pairs(n_pcds, overlap, quadratic_overlap):
        """Return sequential (i, j) pairs with i < j <= i + overlap, plus optional power-of-two jumps."""
        pairs = []
        for i in range(n_pcds - 1):
            for j in range(i + 1, min(i + overlap + 1, n_pcds)):
                pairs.append((i, j))
                if quadratic_overlap:
                    q = 2**(j-i)
                    if q > overlap and i + q < n_pcds:
                        pairs.append((i, i + q))
        return pairs

    def full_registration(pcds_down, pairs, loop_edges):
        """Build a pose graph from ICP results on ``pairs`` and ``loop_edges``."""
        pose_graph = o3d.pipelines.registration.PoseGraph()
        n_pcds = len(pcds_down)

        for _ in range(n_pcds):
            pose_graph.nodes.append(o3d.pipelines.registration.PoseGraphNode(np.identity(4)))

        print("\nPerforming pairwise registration:")
        for source_id, target_id in tqdm(pairs):
            transformation_icp, information_icp, success = pairwise_registration(
                pcds_down[source_id], pcds_down[target_id], use_colored_icp,
                max_correspondence_distance_coarse, max_correspondence_distance_fine)

            if success:
                # Open3D pose-graph convention: odometry edges between
                # consecutive clouds are certain. All other edges are uncertain,
                # so the line process may prune them during global optimization.
                uncertain = target_id != source_id + 1
                pose_graph.edges.append(
                    o3d.pipelines.registration.PoseGraphEdge(source_id,
                                                            target_id,
                                                            transformation_icp,
                                                            information_icp,
                                                            uncertain=uncertain))
            else:
                print(f"  Skipping edge between {source_id} and {target_id} due to ICP failure")

        # Add loop closure edges
        print("\nAdding loop closure edges:")
        for source_id, target_id in tqdm(loop_edges):
            transformation_icp, information_icp, success = pairwise_registration(
                pcds_down[source_id], pcds_down[target_id], use_colored_icp,
                max_correspondence_distance_coarse, max_correspondence_distance_fine)

            if success:
                pose_graph.edges.append(
                    o3d.pipelines.registration.PoseGraphEdge(source_id,
                                                            target_id,
                                                            transformation_icp,
                                                            information_icp,
                                                            uncertain=True))
            else:
                print(f"  Skipping loop closure edge between {source_id} and {target_id} due to ICP failure")

        return pose_graph

    print("\n--- Improved Multiway Registration Process ---")
    print(f"Number of point clouds: {len(pcds)}")
    print(f"Voxel size: {voxel_size}")
    print(f"Max correspondence distance (coarse): {max_correspondence_distance_coarse}")
    print(f"Max correspondence distance (fine): {max_correspondence_distance_fine}")
    print(f"Overlap: {overlap}")
    print(f"Quadratic overlap: {quadratic_overlap}")

    print("\nPreprocessing point clouds...")
    # No preprocessing is applied; the input clouds are registered directly.
    pcds_down = pcds
    print(f"Preprocessing complete. {len(pcds_down)} point clouds processed.")

    print("\nGenerating initial graph pairs...")
    pairs = generate_pairs(len(pcds), overlap, quadratic_overlap)
    print(f"Generated {len(pairs)} pairs for initial graph.")

    if descriptors is None:
        print("\nNo descriptors provided. Skipping loop closure detection.")
        loop_edges = []
    else:
        print("\nDetecting loop closures...")
        # A candidate that coincides with an already-generated sequential pair
        # would be registered a second time and add a duplicate edge.
        sequential_pairs = set(pairs)
        loop_edges = [edge for edge in detect_loop_closure(descriptors)
                      if edge not in sequential_pairs]
        print(f"Detected {len(loop_edges)} loop closures.")

    print("\nPerforming full registration...")
    pose_graph = full_registration(pcds_down, pairs, loop_edges)

    print("\nOptimizing PoseGraph...")
    option = o3d.pipelines.registration.GlobalOptimizationOption(
        max_correspondence_distance=max_correspondence_distance_fine,
        edge_prune_threshold=0.25,
        reference_node=0)
    o3d.pipelines.registration.global_optimization(
        pose_graph,
        o3d.pipelines.registration.GlobalOptimizationLevenbergMarquardt(),
        o3d.pipelines.registration.GlobalOptimizationConvergenceCriteria(),
        option)

    # Count the edges touching each node of the optimized graph.
    edge_count = [0] * len(pcds)
    for edge in pose_graph.edges:
        edge_count[edge.source_node_id] += 1
        edge_count[edge.target_node_id] += 1

    # Keep only clouds connected by more than 3 edges.
    valid_nodes = [count > 3 for count in edge_count]

    print("\nTransforming and combining point clouds...")
    pcd_combined = o3d.geometry.PointCloud()

    for point_id, is_valid in enumerate(valid_nodes):
        if is_valid:
            # Transform a copy so the caller's point clouds are not modified in place.
            pcd_aligned = o3d.geometry.PointCloud(pcds[point_id])
            pcd_aligned.transform(pose_graph.nodes[point_id].pose)
            pcd_combined += pcd_aligned
        else:
            print(f"Skipping point cloud {point_id} as it has {edge_count[point_id]} edges (<=3)")

    print("\nDownsampling the combined point cloud...")
    # Record the size before reassigning, so the log reports both counts.
    n_before = len(pcd_combined.points)
    pcd_combined = pcd_combined.voxel_down_sample(voxel_size * 0.1)
    print(f"Downsampled from {n_before} to {len(pcd_combined.points)} points")

    print("\nCleaning the combined point cloud...")
    _, ind = pcd_combined.remove_statistical_outlier(nb_neighbors=20, std_ratio=2.0)
    pcd_cleaned = pcd_combined.select_by_index(ind)

    print(f"Cleaned point cloud contains {len(pcd_cleaned.points)} points.")
    print(f"Removed {len(pcd_combined.points) - len(pcd_cleaned.points)} outlier points.")

    print("\nMultiway registration complete.")
    # valid_nodes holds one bool per input cloud; count only the merged ones.
    n_valid = sum(valid_nodes)
    print(f"Included {n_valid} out of {len(pcds)} point clouds (with >3 edges).")

    return pcd_cleaned, pose_graph, valid_nodes