import io
import logging
import warnings

import geopandas as gpd
import matplotlib.pyplot as plt
import pandas as pd
import plotly.graph_objects as go
import streamlit as st
from huggingface_hub import hf_hub_download
from shapely.geometry import MultiPolygon, Polygon

# Page-level Streamlit settings: browser-tab title and wide layout.
st.set_page_config(page_title="Geographic Data Visualization", layout="wide")

# Ignore every Python warning category for the rest of the process.
warnings.simplefilter("ignore")

@st.cache_data
def load_data():
    """Return the Natural Earth low-resolution world GeoDataFrame.

    Streamlit re-executes the whole script on every widget interaction, so the
    result is cached with ``st.cache_data``. Without the cache the shapefile
    would be re-read on every click. Each call receives its own copy of the
    cached frame, so callers may modify it freely.

    NOTE: ``geopandas.datasets`` was deprecated in GeoPandas 0.14 and removed in
    1.0. On newer GeoPandas this call fails, and the data has to be read from a
    locally downloaded Natural Earth file instead.
    """
    world_gdf = gpd.read_file(gpd.datasets.get_path('naturalearth_lowres'))
    return world_gdf

def main():
    """Render the app: country picker, legend controls, map views and downloads.

    Region highlighting is implemented only for Pakistan. For any other country
    an informational message is shown instead of the maps.
    """
    st.title("🗺️ Geographic Data Visualization")

    world_gdf = load_data()

    selected_country = st.selectbox(
        "Select a Country",
        sorted(world_gdf['name'].unique())
    )

    # Rows for the chosen country; copied so later work never touches world_gdf.
    gdf = world_gdf[world_gdf['name'] == selected_country].copy()

    # Legend settings shared by the 2D and 3D renderers.
    st.sidebar.header("Legend Settings")
    legend_title = st.sidebar.text_input("Legend Title", "Regions")
    legend_location = st.sidebar.selectbox("Legend Location",
                                           ["best", "upper right", "upper left", "lower left", "lower right"],
                                           index=0)
    legend_font_size = st.sidebar.slider("Legend Font Size", 8, 20, 12)
    legend_font_weight = st.sidebar.selectbox("Legend Font Weight", ["Normal", "Bold"], index=0)
    legend_columns = st.sidebar.slider("Legend Columns", 1, 3, 1)
    legend_transparency = st.sidebar.slider("Legend Transparency", 0.0, 1.0, 0.1)
    legend_bg_color = st.sidebar.color_picker("Legend Background Color", "#FFFFFF")
    legend_border_color = st.sidebar.color_picker("Legend Border Color", "#000000")
    legend_border_width = st.sidebar.slider("Legend Border Width", 0, 5, 1)

    if selected_country != "Pakistan":
        st.info("Region-level maps are currently available only for Pakistan.")
        return

    selected_areas, fill_colors, border_colors = _select_pakistan_areas(gdf)

    st.subheader("2D Map Visualization")
    fig_2d = create_2d_map(gdf, selected_areas, fill_colors, border_colors,
                           legend_title, legend_location, legend_font_size,
                           legend_bg_color, legend_border_color, legend_transparency,
                           legend_columns)
    try:
        st.pyplot(fig_2d)

        st.subheader("3D Map Visualization")
        fig_3d = create_3d_map(gdf, selected_areas, fill_colors, border_colors,
                               legend_title, legend_font_size, legend_font_weight,
                               legend_bg_color, legend_border_color, legend_border_width,
                               legend_columns)
        st.plotly_chart(fig_3d, use_container_width=True)

        st.subheader("🌎 3D Map Visualization (World Map)")
        fig_world = create_3d_world_map(world_gdf, gdf, selected_areas, fill_colors, border_colors,
                                        legend_title, legend_font_size, legend_font_weight,
                                        legend_bg_color, legend_border_color, legend_border_width,
                                        legend_columns)
        st.plotly_chart(fig_world, use_container_width=True)

        _render_download_buttons(fig_2d, fig_3d, fig_world, selected_country)
    finally:
        # st.pyplot does not close a figure it is handed, and Streamlit reruns this
        # script on every interaction; close it so pyplot's figures don't pile up.
        plt.close(fig_2d)


def _select_pakistan_areas(gdf):
    """Render Pakistan's region pickers and return the user's choices.

    Args:
        gdf: GeoDataFrame holding the rows for Pakistan.

    Returns:
        A tuple ``(selected_areas, fill_colors, border_colors)``. The first item
        lists the chosen province and territory names. The other two map each name
        to the colour picked for it. All three are empty when nothing is to be
        highlighted.
    """
    # Offer each NAME_<i> column with at least one non-null value as a level.
    # Only the keys are used (as multiselect options).
    admin_levels = {
        f"Administrative Level {i}": f"NAME_{i}"
        for i in range(6)
        if f"NAME_{i}" in gdf.columns and not gdf[f"NAME_{i}"].isna().all()
    }
    admin_levels["Provinces and Territories"] = "NAME_1"

    pakistan_regions = {
        "Provinces": ["Punjab", "Sindh", "Khyber Pakhtunkhwa", "Balochistan"],
        "Territories": ["Islamabad Capital Territory", "Gilgit-Baltistan", "Azad Kashmir"]
    }

    selected_admin_levels = st.multiselect(
        "Select Administrative Levels to Display",
        list(admin_levels.keys()),
        default=["Provinces and Territories"]
    )

    selected_areas = []
    fill_colors = {}
    border_colors = {}

    if "Provinces and Territories" not in selected_admin_levels:
        return selected_areas, fill_colors, border_colors

    # The renderers match regions on NAME_1. Country-level data (e.g. the
    # naturalearth_lowres frame) lacks that column, and the lookup would raise.
    if "NAME_1" not in gdf.columns:
        st.warning("This dataset has no province-level (NAME_1) boundaries, "
                   "so provinces and territories cannot be highlighted.")
        return selected_areas, fill_colors, border_colors

    col1, col2 = st.columns(2)
    with col1:
        selected_provinces = st.multiselect(
            "Select Provinces",
            pakistan_regions["Provinces"],
            default=pakistan_regions["Provinces"]
        )
    with col2:
        selected_territories = st.multiselect(
            "Select Territories",
            pakistan_regions["Territories"],
            default=pakistan_regions["Territories"]
        )

    for area in selected_provinces + selected_territories:
        selected_areas.append(area)
        fill_colors[area] = st.color_picker(f"Fill color for {area}", "#3498db")
        border_colors[area] = st.color_picker(f"Border color for {area}", "#2980b9")

    return selected_areas, fill_colors, border_colors


def _render_download_buttons(fig_2d, fig_3d, fig_world, selected_country):
    """Offer the rendered maps as downloads, generating each file only on request."""
    col1, col2, col3 = st.columns(3)
    with col1:
        if st.button("Download 2D Map"):
            # Save this figure explicitly (not pyplot's "current" one) into memory.
            # This avoids a shared file on the server that concurrent sessions
            # could overwrite.
            png_buffer = io.BytesIO()
            fig_2d.savefig(png_buffer, format="png", dpi=300, bbox_inches='tight')
            st.download_button(
                label="Click to Download 2D Map",
                data=png_buffer.getvalue(),
                file_name="map_2d.png",
                mime="image/png"
            )

    with col2:
        if st.button("Download 3D Map"):
            st.download_button(
                label="Click to Download 3D Map",
                data=fig_3d.to_html(),
                file_name="map_3d.html",
                mime="text/html"
            )

    with col3:
        if st.button("Download 3D World Map"):
            st.download_button(
                label="Click to Download 3D World Map",
                data=fig_world.to_html(),
                file_name=f"{selected_country.lower()}_world_map_3d.html",
                mime="text/html"
            )

def create_2d_map(gdf, selected_areas, fill_colors, border_colors,
                  legend_title, legend_location, legend_font_size,
                  legend_bg_color, legend_border_color, legend_transparency,
                  legend_columns):
    """Draw a static matplotlib map of ``gdf`` with the selected regions highlighted.

    Args:
        gdf: GeoDataFrame drawn as the faint base layer. Rows are matched to
            ``selected_areas`` through its ``NAME_1`` column. If that column is
            absent, no region is highlighted (the legend is still drawn).
        selected_areas: Region names to highlight, in legend order.
        fill_colors: Mapping of region name -> fill colour.
        border_colors: Mapping of region name -> edge colour.
        legend_title: Legend title text.
        legend_location: Matplotlib legend ``loc`` string (e.g. ``"best"``).
        legend_font_size: Entry font size; the title uses this size + 2.
        legend_bg_color: Legend frame face colour.
        legend_border_color: Legend frame edge colour.
        legend_transparency: 0.0 (opaque frame) to 1.0 (invisible frame).
        legend_columns: Number of legend columns.

    Returns:
        A new ``matplotlib.figure.Figure``. It stays registered with pyplot, so the
        caller is responsible for closing it.
    """
    fig, ax = plt.subplots(figsize=(15, 10))

    # Faint base layer so the highlighted regions stand out.
    gdf.plot(ax=ax, color='#EEEEEE', alpha=0.3, edgecolor='#CCCCCC')

    # Guard the lookup: datasets without NAME_1 would otherwise raise KeyError.
    if 'NAME_1' in gdf.columns:
        for area in selected_areas:
            area_gdf = gdf[gdf['NAME_1'] == area]
            if area_gdf.empty:
                continue  # name not present in this dataset; nothing to draw
            area_gdf.plot(ax=ax, color=fill_colors[area], alpha=0.5,
                          edgecolor=border_colors[area], linewidth=1)

    # Proxy artists: one swatch per selected area, styled like its polygons.
    handles = [
        plt.Rectangle((0, 0), 1, 1, fc=fill_colors[area],
                      ec=border_colors[area], alpha=0.5)
        for area in selected_areas
    ]

    ax.legend(handles, selected_areas,
              title=legend_title,
              loc=legend_location,
              fontsize=legend_font_size,
              title_fontsize=legend_font_size + 2,
              frameon=True,
              facecolor=legend_bg_color,
              edgecolor=legend_border_color,
              framealpha=1 - legend_transparency,
              ncol=legend_columns)

    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_frame_on(False)

    return fig

def create_3d_map(gdf, selected_areas, fill_colors, border_colors,
                  legend_title, legend_font_size, legend_font_weight,
                  legend_bg_color, legend_border_color, legend_border_width,
                  legend_columns):
    """Build a Plotly 3D figure of ``gdf`` outlines with the selected regions highlighted.

    All geometry is drawn flat on the z=0 plane. Only exterior rings are drawn;
    polygon holes are ignored.

    Args:
        gdf: GeoDataFrame whose polygon outlines form the light-grey base layer.
            Rows are matched to ``selected_areas`` through its ``NAME_1`` column.
            Without that column, no region is highlighted.
        selected_areas: Region names to highlight; each gets one legend entry.
        fill_colors: Mapping of region name -> vertex-marker colour.
        border_colors: Mapping of region name -> outline colour.
        legend_title: Legend title text.
        legend_font_size: Entry font size; the title uses this size + 2.
        legend_font_weight: "Normal" or "Bold" (lower-cased for Plotly).
        legend_bg_color: Legend background colour.
        legend_border_color: Legend border colour.
        legend_border_width: Legend border width in pixels.
        legend_columns: Values > 1 switch the legend to horizontal orientation.

    Returns:
        The assembled ``plotly.graph_objects.Figure``.
    """

    def polygons_of(geometries):
        """Yield the non-empty Polygons in ``geometries``, expanding MultiPolygons.

        Missing (None) and non-polygonal geometries are skipped.
        """
        for geom in geometries:
            if isinstance(geom, MultiPolygon):
                parts = geom.geoms
            elif isinstance(geom, Polygon):
                parts = (geom,)
            else:
                continue
            for poly in parts:
                if not poly.is_empty:
                    yield poly

    def ring_coords(polygons):
        """Concatenate exterior rings into x/y/z lists separated by ``None``.

        Plotly breaks a line at ``None`` values, so any number of rings can share
        a single Scatter3d trace.
        """
        xs, ys, zs = [], [], []
        for poly in polygons:
            x, y = poly.exterior.xy
            xs.extend(x)
            ys.extend(y)
            zs.extend([0] * len(x))
            xs.append(None)
            ys.append(None)
            zs.append(None)
        return xs, ys, zs

    fig = go.Figure()

    # Base layer: every outline in one trace instead of one trace per polygon.
    xs, ys, zs = ring_coords(polygons_of(gdf.geometry))
    if xs:
        fig.add_trace(go.Scatter3d(
            x=xs, y=ys, z=zs,
            mode='lines',
            line=dict(color='#CCCCCC', width=1),
            showlegend=False
        ))

    # One trace per area, so its single legend entry toggles every part of it.
    if 'NAME_1' in gdf.columns:
        for area in selected_areas:
            area_geoms = gdf.loc[gdf['NAME_1'] == area].geometry
            xs, ys, zs = ring_coords(polygons_of(area_geoms))
            if not xs:
                continue  # area not present in this dataset
            fig.add_trace(go.Scatter3d(
                x=xs, y=ys, z=zs,
                mode='lines+markers',
                line=dict(color=border_colors[area], width=2),
                marker=dict(size=1, color=fill_colors[area], opacity=0.5),
                name=area,
                showlegend=True
            ))

    fig.update_layout(
        scene=dict(
            aspectmode='data',
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            zaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[-1, 1])
        ),
        margin=dict(t=30, b=0, l=0, r=0),
        legend=dict(
            font=dict(
                size=legend_font_size,
                weight=legend_font_weight.lower()
            ),
            bgcolor=legend_bg_color,
            bordercolor=legend_border_color,
            borderwidth=legend_border_width,
            orientation="h" if legend_columns > 1 else "v",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1,
            title=dict(
                text=legend_title,
                font=dict(
                    size=legend_font_size + 2
                )
            ),
            itemsizing='constant'
        ),
        showlegend=True,
        height=800
    )

    return fig

# main() is invoked exactly once, from the guarded call at the bottom of the
# module, after every helper it may use (including create_3d_world_map) is defined.



# 3D view of the highlighted regions drawn over the full world outline.

def create_3d_world_map(world_gdf, selected_gdf, selected_areas, fill_colors, border_colors,
                        legend_title, legend_font_size, legend_font_weight,
                        legend_bg_color, legend_border_color, legend_border_width,
                        legend_columns):
    """Build a Plotly 3D figure with highlighted regions over a world outline.

    All geometry is drawn flat on the z=0 plane. Only exterior rings are used;
    polygon holes are ignored.

    Args:
        world_gdf: GeoDataFrame whose polygons are drawn as light-grey outlines.
        selected_gdf: GeoDataFrame holding the regions. Rows are matched to
            ``selected_areas`` through its ``NAME_1`` column; without that
            column, no region is highlighted.
        selected_areas: Region names to highlight; each gets one legend entry.
        fill_colors: Mapping of region name -> fill / vertex-marker colour.
        border_colors: Mapping of region name -> outline colour.
        legend_title: Legend title text.
        legend_font_size: Entry font size; the title uses this size + 2.
        legend_font_weight: "Normal" or "Bold" (lower-cased for Plotly).
        legend_bg_color: Legend background colour.
        legend_border_color: Legend border colour.
        legend_border_width: Legend border width in pixels.
        legend_columns: Values > 1 switch the legend to horizontal, grouped order.

    Returns:
        The assembled ``plotly.graph_objects.Figure``.
    """

    def polygons_of(geometries):
        """Yield the non-empty Polygons in ``geometries``, expanding MultiPolygons.

        Missing (None) and non-polygonal geometries are skipped.
        """
        for geom in geometries:
            if isinstance(geom, MultiPolygon):
                parts = geom.geoms
            elif isinstance(geom, Polygon):
                parts = (geom,)
            else:
                continue
            for poly in parts:
                if not poly.is_empty:
                    yield poly

    def ring_coords(polygons):
        """Concatenate exterior rings into x/y/z lists separated by ``None``.

        Plotly breaks a line at ``None`` values, so any number of rings can share
        a single Scatter3d trace.
        """
        xs, ys, zs = [], [], []
        for poly in polygons:
            x, y = poly.exterior.xy
            xs.extend(x)
            ys.extend(y)
            zs.extend([0] * len(x))
            xs.append(None)
            ys.append(None)
            zs.append(None)
        return xs, ys, zs

    fig_world = go.Figure()

    # World outline in one trace; one trace per polygon meant hundreds of traces.
    xs, ys, zs = ring_coords(polygons_of(world_gdf.geometry))
    if xs:
        fig_world.add_trace(go.Scatter3d(
            x=xs, y=ys, z=zs,
            mode='lines',
            line=dict(color='#CCCCCC', width=1),
            showlegend=False
        ))

    if 'NAME_1' in selected_gdf.columns:
        for area in selected_areas:
            area_geoms = selected_gdf.loc[selected_gdf['NAME_1'] == area].geometry
            polygons = list(polygons_of(area_geoms))
            if not polygons:
                continue  # area not present in this dataset

            # Outline: one legend entry per area. legendgroup ties the fills below
            # to it, so toggling the entry hides the whole area.
            xs, ys, zs = ring_coords(polygons)
            fig_world.add_trace(go.Scatter3d(
                x=xs, y=ys, z=zs,
                mode='lines+markers',
                line=dict(color=border_colors[area], width=2),
                marker=dict(size=1, color=fill_colors[area], opacity=0.5),
                name=area,
                legendgroup=area,
                showlegend=True
            ))

            # NOTE: with no i/j/k indices Plotly triangulates these points itself
            # (Delaunay by default). That fills the convex hull of each ring, so
            # concave regions come out over-filled.
            for poly in polygons:
                x, y = poly.exterior.xy
                # Drop the repeated closing vertex; the fill needs each point once.
                x, y = list(x)[:-1], list(y)[:-1]
                fig_world.add_trace(go.Mesh3d(
                    x=x, y=y, z=[0] * len(x),
                    color=fill_colors[area],
                    opacity=0.3,
                    legendgroup=area,
                    showlegend=False
                ))

    fig_world.update_layout(
        scene=dict(
            aspectmode='data',
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            zaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[-1, 1])
        ),
        margin=dict(t=30, b=0, l=0, r=0),
        legend=dict(
            font=dict(
                size=legend_font_size,
                weight=legend_font_weight.lower()
            ),
            bgcolor=legend_bg_color,
            bordercolor=legend_border_color,
            borderwidth=legend_border_width,
            orientation="h" if legend_columns > 1 else "v",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1,
            title=dict(
                text=legend_title,
                font=dict(
                    size=legend_font_size + 2
                )
            ),
            itemsizing='constant',
            traceorder='normal' if legend_columns == 1 else 'grouped'
        ),
        showlegend=True,
        height=800
    )
    return fig_world


# Script entry point. Any unexpected failure is reported in the app as a short
# error message, and its full traceback goes to the server log.
if __name__ == "__main__":
    try:
        main()
    except Exception as e:  # top-level boundary: report instead of crashing the page
        # The UI shows only the message, so keep the traceback for debugging.
        logging.getLogger(__name__).exception("Unhandled error while rendering the app")
        st.error(f"⚠️ Error: {e}")