# import numpy as np
# import plotly.graph_objects as go
# from scipy.interpolate import griddata

# def gen_three_D_plot(detectability_val, distortion_val, euclidean_val):
#     detectability = np.array(detectability_val)
#     distortion = np.array(distortion_val)
#     euclidean = np.array(euclidean_val)

#     # Find the closest point to the origin
#     distances_to_origin = np.linalg.norm(np.array([distortion, detectability, euclidean]).T, axis=1)
#     closest_point_index = np.argmin(distances_to_origin)

#     # Determine the closest points to each axis
#     closest_to_x_axis = np.argmin(distortion)
#     closest_to_y_axis = np.argmin(detectability)
#     closest_to_z_axis = np.argmin(euclidean)

#     # Use the detected closest point as the "sweet spot"
#     sweet_spot_detectability = detectability[closest_point_index]
#     sweet_spot_distortion = distortion[closest_point_index]
#     sweet_spot_euclidean = euclidean[closest_point_index]

#     # Create a meshgrid from the data
#     x_grid, y_grid = np.meshgrid(np.linspace(min(detectability), max(detectability), 30),
#                                  np.linspace(min(distortion), max(distortion), 30))

#     # Interpolate z values (Euclidean distances) to fit the grid
#     z_grid = griddata((detectability, distortion), euclidean, (x_grid, y_grid), method='linear')

#     if z_grid is None:
#         raise ValueError("griddata could not generate a valid interpolation. Check your input data.")

#     # Create the 3D contour plot with the Plasma color scale
#     fig = go.Figure(data=go.Surface(
#         z=z_grid, 
#         x=x_grid, 
#         y=y_grid, 
#         contours={
#             "z": {"show": True, "start": min(euclidean), "end": max(euclidean), "size": 0.1, "usecolormap": True}
#         },
#         colorscale='Plasma'
#     ))

#     # Add a marker for the sweet spot
#     fig.add_trace(go.Scatter3d(
#         x=[sweet_spot_detectability],
#         y=[sweet_spot_distortion],
#         z=[sweet_spot_euclidean],
#         mode='markers+text',
#         marker=dict(size=10, color='red', symbol='circle'),
#         text=["Sweet Spot"],
#         textposition="top center"
#     ))

#     # Set axis labels
#     fig.update_layout(
#         scene=dict(
#             xaxis_title='Detectability Score',
#             yaxis_title='Distortion Score',
#             zaxis_title='Euclidean Distance'
#         ),
#         margin=dict(l=0, r=0, b=0, t=0)
#     )

#     return fig


import numpy as np
import plotly.graph_objects as go
from scipy.interpolate import griddata

def _min_max_normalize(values):
    """Linearly rescale a 1-D float array onto [0, 1].

    A constant array (zero range) is mapped to all zeros instead of the
    NaNs a plain ``(x - min) / (max - min)`` would produce, so a degenerate
    metric contributes nothing to a composite score rather than poisoning it.
    """
    lo = values.min()
    span = values.max() - lo
    if span == 0:
        return np.zeros_like(values, dtype=float)
    return (values - lo) / span


def gen_three_D_plot(detectability_val, distortion_val, euclidean_val, *,
                     grid_points=30, contour_size=0.1):
    """Build a 3-D surface plot of Euclidean distance over (detectability, distortion).

    The Euclidean values are linearly interpolated onto a regular
    ``grid_points`` x ``grid_points`` grid spanning the observed detectability
    (x axis) and distortion (y axis) ranges. A red "Sweet Spot" marker is
    placed on the input sample with the best composite score. The score is
    the normalized detectability minus the sum of the normalized distortion
    and Euclidean values, so higher detectability and lower distortion and
    distance are preferred.

    Args:
        detectability_val: Sequence of detectability scores (x coordinates).
        distortion_val: Sequence of distortion scores (y coordinates); same
            length as ``detectability_val``.
        euclidean_val: Sequence of Euclidean distances (z values); same length
            as the other two.
        grid_points: Number of grid samples along each horizontal axis.
        contour_size: Spacing between z contour lines on the surface.

    Returns:
        A ``plotly.graph_objects.Figure`` with the interpolated surface and the
        sweet-spot marker.

    Raises:
        ValueError: If the inputs differ in shape, are empty, or the
            interpolation yields no finite values. ``griddata`` may also raise
            its own error for degenerate (e.g. collinear) input points.
    """
    detectability = np.asarray(detectability_val, dtype=float)
    distortion = np.asarray(distortion_val, dtype=float)
    euclidean = np.asarray(euclidean_val, dtype=float)

    if not (detectability.shape == distortion.shape == euclidean.shape):
        raise ValueError(
            "detectability_val, distortion_val and euclidean_val must have the same shape."
        )
    if detectability.size == 0:
        raise ValueError("At least one data point is required.")

    # Composite score: maximize detectability, minimize distortion and
    # Euclidean distance. Each metric is normalized so they are comparable;
    # constant metrics normalize to 0 instead of NaN (see helper).
    composite_score = _min_max_normalize(detectability) - (
        _min_max_normalize(distortion) + _min_max_normalize(euclidean)
    )
    sweet_spot_index = int(np.argmax(composite_score))

    sweet_spot_detectability = float(detectability[sweet_spot_index])
    sweet_spot_distortion = float(distortion[sweet_spot_index])
    sweet_spot_euclidean = float(euclidean[sweet_spot_index])

    # Regular grid over the observed (detectability, distortion) bounding box.
    x_grid, y_grid = np.meshgrid(
        np.linspace(detectability.min(), detectability.max(), grid_points),
        np.linspace(distortion.min(), distortion.max(), grid_points),
    )

    # Interpolate z values (Euclidean distances) onto the grid.
    z_grid = griddata((detectability, distortion), euclidean, (x_grid, y_grid), method='linear')

    # griddata never returns None: grid points outside the convex hull of the
    # input come back as NaN. An all-NaN grid means interpolation failed.
    if np.all(np.isnan(z_grid)):
        raise ValueError("griddata could not generate a valid interpolation. Check your input data.")

    # Create the 3D contour plot with the Plasma color scale.
    fig = go.Figure(data=go.Surface(
        z=z_grid,
        x=x_grid,
        y=y_grid,
        contours={
            "z": {
                "show": True,
                "start": float(euclidean.min()),
                "end": float(euclidean.max()),
                "size": contour_size,
                "usecolormap": True,
            }
        },
        colorscale='Plasma'
    ))

    # Add a marker for the sweet spot.
    fig.add_trace(go.Scatter3d(
        x=[sweet_spot_detectability],
        y=[sweet_spot_distortion],
        z=[sweet_spot_euclidean],
        mode='markers+text',
        marker=dict(size=10, color='red', symbol='circle'),
        text=["Sweet Spot"],
        textposition="top center"
    ))

    # Axis labels must match the grid orientation above (x=detectability, y=distortion).
    fig.update_layout(
        scene=dict(
            xaxis_title='Detectability Score',
            yaxis_title='Distortion Score',
            zaxis_title='Euclidean Distance'
        ),
        margin=dict(l=0, r=0, b=0, t=0)
    )

    return fig