import streamlit as st
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

def render_dataset_visualization(dataset, dataset_type):
    """
    Render an interactive visualization panel for a dataset in Streamlit.

    The user picks one of five modes ("Distribution", "Correlation",
    "Categories", "Time Series", "Custom"); each mode is drawn by a private
    helper defined below. Nothing is returned: all output goes through
    Streamlit widgets and Plotly charts.

    Args:
        dataset: The dataset to visualize (pandas DataFrame). ``None`` or an
            empty frame only produces a warning.
        dataset_type: The type of dataset (csv, json, etc.). Not used by the
            rendering logic; kept so existing callers keep working.
    """
    if dataset is None or dataset.empty:
        st.warning("No dataset to visualize.")
        return

    st.markdown("<h3>Dataset Visualization</h3>", unsafe_allow_html=True)

    # Classify columns once; the helpers only receive the lists they need.
    numeric_cols = dataset.select_dtypes(include=[np.number]).columns.tolist()
    categorical_cols = dataset.select_dtypes(include=['object', 'category']).columns.tolist()
    # is_datetime64_any_dtype also recognises timezone-aware and
    # non-nanosecond datetime columns, which a comparison against the literal
    # 'datetime64[ns]' would miss.
    date_cols = [
        col for col in dataset.columns
        if pd.api.types.is_datetime64_any_dtype(dataset[col])
    ]

    viz_type = st.selectbox(
        "Select visualization type",
        ["Distribution", "Correlation", "Categories", "Time Series", "Custom"],
        help="Choose the type of visualization to create"
    )

    if viz_type == "Distribution":
        _render_distribution(dataset, numeric_cols)
    elif viz_type == "Correlation":
        _render_correlation(dataset, numeric_cols)
    elif viz_type == "Categories":
        _render_categories(dataset, numeric_cols, categorical_cols)
    elif viz_type == "Time Series":
        _render_time_series(dataset, numeric_cols, categorical_cols, date_cols)
    elif viz_type == "Custom":
        _render_custom(dataset, numeric_cols, categorical_cols)


def _smoothed_density(series):
    """
    Estimate a smoothed probability density for a numeric Series.

    The density comes from a normalised histogram (``np.histogram`` with
    ``density=True``) that is then smoothed with a 5-bin moving average, so
    it shares the y-axis scale of a 'probability density' histogram.

    Returns:
        A ``(bin_centers, density)`` tuple of numpy arrays, or ``None`` when
        the series holds no finite values.
    """
    values = series.dropna().to_numpy(dtype=float)
    values = values[np.isfinite(values)]  # np.histogram rejects inf ranges
    if values.size == 0:
        return None

    # Square-root rule, clamped so tiny inputs still get a curve and huge
    # inputs do not allocate an excessive number of bins.
    n_bins = int(min(100, max(10, np.sqrt(values.size))))
    density, edges = np.histogram(values, bins=n_bins, density=True)
    centers = (edges[:-1] + edges[1:]) / 2

    # Pad with edge values so the moving average does not pull the curve
    # towards zero at both ends; 'valid' then returns exactly n_bins points.
    kernel_width = 5
    padded = np.pad(density, kernel_width // 2, mode="edge")
    kernel = np.ones(kernel_width) / kernel_width
    smoothed = np.convolve(padded, kernel, mode="valid")
    return centers, smoothed


def _render_distribution(dataset, numeric_cols):
    """Draw histograms and summary statistics for selected numeric columns."""
    if not numeric_cols:
        st.warning("No numeric columns found for distribution visualization.")
        return

    selected_cols = st.multiselect(
        "Select columns to visualize",
        numeric_cols,
        default=numeric_cols[:3]
    )
    if not selected_cols:
        st.warning("Please select at least one column to visualize.")
        return

    if len(selected_cols) == 1:
        # Single column: density-normalised histogram plus a smoothed curve.
        col = selected_cols[0]
        fig = px.histogram(
            dataset,
            x=col,
            histnorm='probability density',
            title=f"Distribution of {col}",
            color_discrete_sequence=["#FFD21E"],
            template="simple_white"
        )
        curve = _smoothed_density(dataset[col])
        if curve is not None:
            centers, density = curve
            fig.add_trace(
                go.Scatter(
                    x=centers,
                    y=density,
                    mode='lines',
                    line=dict(color="#2563EB", width=3),
                    name='Smoothed'
                )
            )
        st.plotly_chart(fig, use_container_width=True)
    else:
        # Several columns: one histogram per column in a two-column grid.
        num_cols = min(len(selected_cols), 2)
        num_rows = (len(selected_cols) + num_cols - 1) // num_cols

        fig = make_subplots(
            rows=num_rows,
            cols=num_cols,
            subplot_titles=[f"Distribution of {col}" for col in selected_cols]
        )
        for i, col in enumerate(selected_cols):
            fig.add_trace(
                go.Histogram(
                    x=dataset[col],
                    name=col,
                    marker_color="#FFD21E"
                ),
                row=i // num_cols + 1,
                col=i % num_cols + 1
            )
        fig.update_layout(
            title="Distribution of Selected Features",
            showlegend=False,
            template="simple_white",
            height=300 * num_rows
        )
        st.plotly_chart(fig, use_container_width=True)

    st.markdown("### Distribution Statistics")
    stats_df = dataset[selected_cols].describe().T
    st.dataframe(stats_df, use_container_width=True)


def _render_correlation(dataset, numeric_cols):
    """Draw a correlation heatmap, scatter matrix and ranked pair chart."""
    if len(numeric_cols) < 2:
        st.warning("Need at least two numeric columns for correlation analysis.")
        return

    st.markdown("### Correlation Matrix")
    selected_cols = st.multiselect(
        "Select columns for correlation analysis",
        numeric_cols,
        default=numeric_cols[:5]
    )
    if len(selected_cols) < 2:
        st.warning("Please select at least two columns for correlation analysis.")
        return

    corr = dataset[selected_cols].corr()

    fig = px.imshow(
        corr,
        color_continuous_scale="RdBu_r",
        title="Correlation Matrix",
        template="simple_white",
        text_auto=True
    )
    st.plotly_chart(fig, use_container_width=True)

    # The scatter matrix is limited to 3-5 columns for readability.
    if 2 < len(selected_cols) <= 5:
        st.markdown("### Scatter Plot Matrix")
        fig = px.scatter_matrix(
            dataset,
            dimensions=selected_cols,
            color_discrete_sequence=["#2563EB"],
            title="Scatter Plot Matrix",
            template="simple_white"
        )
        fig.update_traces(diagonal_visible=False)
        st.plotly_chart(fig, use_container_width=True)

    st.markdown("### Top Correlation Pairs")

    # Upper triangle only, so each unordered pair appears once. NaN entries
    # (e.g. from a constant column) are skipped: they cannot be ranked and
    # would make the sort order below undefined.
    names = list(corr.columns)
    pairs = [
        (names[i], names[j], corr.iloc[i, j])
        for i in range(len(names))
        for j in range(i + 1, len(names))
        if pd.notna(corr.iloc[i, j])
    ]
    pairs.sort(key=lambda pair: abs(pair[2]), reverse=True)

    if pairs:
        fig = px.bar(
            x=[f"{first} & {second}" for first, second, _ in pairs],
            y=[abs(value) for _, _, value in pairs],
            color=[value for _, _, value in pairs],
            color_continuous_scale="RdBu_r",
            labels={'x': 'Feature Pairs', 'y': 'Absolute Correlation'},
            title="Top Feature Correlations"
        )
        st.plotly_chart(fig, use_container_width=True)


def _render_categories(dataset, numeric_cols, categorical_cols):
    """Draw category counts and an optional per-category numeric breakdown."""
    if not categorical_cols:
        st.warning("No categorical columns found for category visualization.")
        return

    selected_cat = st.selectbox("Select categorical column", categorical_cols)
    value_counts = dataset[selected_cat].value_counts()

    # value_counts() is sorted by frequency, so head(20) keeps the top 20.
    if len(value_counts) > 20:
        st.info(f"Showing top 20 categories out of {len(value_counts)}")
        value_counts = value_counts.head(20)

    fig = px.bar(
        x=value_counts.index,
        y=value_counts.values,
        title=f"Category Counts for {selected_cat}",
        labels={'x': selected_cat, 'y': 'Count'},
        color_discrete_sequence=["#FFD21E"]
    )
    st.plotly_chart(fig, use_container_width=True)

    if not numeric_cols:
        return

    st.markdown(f"### {selected_cat} vs Numeric Features")
    selected_num = st.selectbox("Select numeric column", numeric_cols)

    fig = px.box(
        dataset,
        x=selected_cat,
        y=selected_num,
        title=f"{selected_cat} vs {selected_num}",
        color_discrete_sequence=["#2563EB"],
        template="simple_white"
    )
    st.plotly_chart(fig, use_container_width=True)

    st.markdown(f"### Statistics of {selected_num} by {selected_cat}")
    stats_by_cat = dataset.groupby(selected_cat)[selected_num].describe()
    st.dataframe(stats_by_cat, use_container_width=True)


def _render_time_series(dataset, numeric_cols, categorical_cols, date_cols):
    """Aggregate a numeric column over a chosen time period and plot it."""
    candidates = list(date_cols)

    # Cheap heuristic: treat a text column as date-like when its first few
    # non-null values all contain '/' or '-'. False positives are caught by
    # the conversion check further down.
    for col in categorical_cols:
        sample = dataset[col].dropna().head(5).tolist()
        if sample and all('/' in str(x) or '-' in str(x) for x in sample):
            candidates.append(col)

    if not candidates:
        st.warning("No date columns found for time series visualization.")
        return

    date_col = st.selectbox("Select date column", candidates)

    # Work on the single date Series instead of copying the whole frame.
    dates = dataset[date_col]
    if not pd.api.types.is_datetime64_any_dtype(dates):
        try:
            dates = pd.to_datetime(dates)
        except (ValueError, TypeError, OverflowError):
            st.error(f"Could not convert {date_col} to datetime.")
            return
        # Some inputs (e.g. mixed UTC offsets) can come back as an object
        # Series rather than raising; the .dt accessor would fail on those.
        if not pd.api.types.is_datetime64_any_dtype(dates):
            st.error(f"Could not convert {date_col} to datetime.")
            return

    # Drop timezone info (keeping local wall time) so every period bucket
    # below is a plain, timezone-naive timestamp.
    if dates.dt.tz is not None:
        dates = dates.dt.tz_localize(None)

    if not numeric_cols:
        st.warning("No numeric columns found for time series values.")
        return

    value_col = st.selectbox("Select value column", numeric_cols)
    time_period = st.selectbox(
        "Aggregate by",
        ["Day", "Week", "Month", "Quarter", "Year"]
    )

    # Every sub-year bucket is the period's start timestamp, which keeps a
    # datetime dtype that both the line chart and the trendline can use.
    period_freq = {"Day": "D", "Week": "W", "Month": "M", "Quarter": "Q"}
    if time_period == "Year":
        period = dates.dt.year
    else:
        period = dates.dt.to_period(period_freq[time_period]).dt.start_time

    agg_method = st.selectbox("Aggregation method", ["Mean", "Sum", "Min", "Max", "Count"])
    agg_map = {
        "Mean": "mean",
        "Sum": "sum",
        "Min": "min",
        "Max": "max",
        "Count": "count"
    }

    # Group the value Series by the period Series directly; this never
    # touches (or overwrites) a dataset column that happens to be called
    # 'period'. The label only changes if the value column has that name.
    period_label = "period" if value_col != "period" else "time_period"
    time_series = (
        dataset[value_col]
        .groupby(period.rename(period_label))
        .agg(agg_map[agg_method])
        .reset_index()
    )

    fig = px.line(
        time_series,
        x=period_label,
        y=value_col,
        title=f"{agg_method} of {value_col} by {time_period}",
        markers=True,
        color_discrete_sequence=["#2563EB"],
        template="simple_white"
    )
    fig.update_layout(
        xaxis_title=time_period,
        yaxis_title=f"{agg_method} of {value_col}"
    )
    st.plotly_chart(fig, use_container_width=True)

    if st.checkbox("Show trendline"):
        try:
            # trendline="ols" depends on the optional statsmodels package.
            trend_fig = px.scatter(
                time_series,
                x=period_label,
                y=value_col,
                trendline="ols",
                title=f"{agg_method} of {value_col} by {time_period} with Trendline",
                color_discrete_sequence=["#2563EB"],
                template="simple_white"
            )
        except (ImportError, ValueError) as exc:
            st.error(f"Could not compute trendline: {exc}")
        else:
            trend_fig.update_layout(
                xaxis_title=time_period,
                yaxis_title=f"{agg_method} of {value_col}"
            )
            st.plotly_chart(trend_fig, use_container_width=True)

    st.dataframe(time_series, use_container_width=True)


def _select_color_column(dataset):
    """Show the optional "color by" controls; return the column or None."""
    if st.checkbox("Add color dimension"):
        return st.selectbox("Color by", dataset.columns.tolist())
    return None


def _render_custom(dataset, numeric_cols, categorical_cols):
    """Let the user pick a plot type and dispatch to the matching builder."""
    st.markdown("### Custom Visualization")
    st.info("Create a custom plot by selecting axes and plot type")

    plot_type = st.selectbox(
        "Select plot type",
        ["Scatter", "Line", "Bar", "Box", "Violin", "Histogram", "Pie", "3D Scatter"]
    )

    if plot_type in ["Scatter", "Line", "Bar", "3D Scatter"]:
        _render_custom_xy(dataset, plot_type, numeric_cols)
    elif plot_type in ["Box", "Violin"]:
        _render_custom_box_violin(dataset, plot_type, numeric_cols, categorical_cols)
    elif plot_type == "Histogram":
        _render_custom_histogram(dataset)
    elif plot_type == "Pie":
        _render_custom_pie(dataset, numeric_cols, categorical_cols)


def _render_custom_xy(dataset, plot_type, numeric_cols):
    """Draw a custom scatter, line, bar or 3D scatter plot."""
    all_cols = dataset.columns.tolist()
    # Prefer numeric columns for value axes; fall back to every column.
    value_options = numeric_cols or all_cols

    x_col = st.selectbox("X-axis", all_cols)
    y_col = st.selectbox("Y-axis", value_options)
    z_col = None
    if plot_type == "3D Scatter":
        z_col = st.selectbox("Z-axis", value_options)

    color_col = _select_color_column(dataset)

    if plot_type == "Scatter":
        fig = px.scatter(
            dataset,
            x=x_col,
            y=y_col,
            color=color_col,
            title=f"{y_col} vs {x_col}",
            template="simple_white"
        )
    elif plot_type == "Line":
        # Sort by x so the line is drawn left to right.
        fig = px.line(
            dataset.sort_values(x_col),
            x=x_col,
            y=y_col,
            color=color_col,
            title=f"{y_col} vs {x_col}",
            template="simple_white"
        )
    elif plot_type == "Bar":
        fig = px.bar(
            dataset,
            x=x_col,
            y=y_col,
            color=color_col,
            title=f"{y_col} by {x_col}",
            template="simple_white"
        )
    else:  # 3D Scatter
        fig = px.scatter_3d(
            dataset,
            x=x_col,
            y=y_col,
            z=z_col,
            color=color_col,
            title=f"3D Scatter: {x_col}, {y_col}, {z_col}",
            template="simple_white"
        )

    st.plotly_chart(fig, use_container_width=True)


def _render_custom_box_violin(dataset, plot_type, numeric_cols, categorical_cols):
    """Draw a custom box or violin plot of a value column per category."""
    all_cols = dataset.columns.tolist()
    x_col = st.selectbox("X-axis (categories)", categorical_cols or all_cols)
    y_col = st.selectbox("Y-axis (values)", numeric_cols or all_cols)

    color_col = _select_color_column(dataset)

    if plot_type == "Box":
        fig = px.box(
            dataset,
            x=x_col,
            y=y_col,
            color=color_col,
            title=f"Box Plot: {y_col} by {x_col}",
            template="simple_white"
        )
    else:  # Violin
        fig = px.violin(
            dataset,
            x=x_col,
            y=y_col,
            color=color_col,
            title=f"Violin Plot: {y_col} by {x_col}",
            template="simple_white"
        )

    st.plotly_chart(fig, use_container_width=True)


def _render_custom_histogram(dataset):
    """Draw a custom histogram of one column with a chosen bin count."""
    value_col = st.selectbox("Value column", dataset.columns.tolist())
    n_bins = st.slider("Number of bins", 5, 100, 20)

    color_col = _select_color_column(dataset)

    fig = px.histogram(
        dataset,
        x=value_col,
        color=color_col,
        nbins=n_bins,
        title=f"Histogram of {value_col}",
        template="simple_white"
    )
    st.plotly_chart(fig, use_container_width=True)


def _render_custom_pie(dataset, numeric_cols, categorical_cols):
    """Draw a custom pie chart of category counts or summed values."""
    cat_col = st.selectbox("Category column", categorical_cols or dataset.columns.tolist())

    # The value selector only appears when numeric columns exist; otherwise
    # the chart falls back to plain category counts.
    use_values = st.checkbox("Use custom values")
    value_col = None
    if use_values and numeric_cols:
        value_col = st.selectbox("Value column", numeric_cols)

    top_n = st.slider("Limit to top N categories", 0, 20, 10,
        help="Set to 0 to show all categories. Recommended to limit to top 10-15 categories for readability.")

    if value_col is not None:
        pie_data = dataset.groupby(cat_col)[value_col].sum().reset_index()
        pie_data = pie_data.sort_values(value_col, ascending=False)
    else:
        # Avoid two identically named columns when the category column is
        # itself called 'count'.
        value_col = 'count' if cat_col != 'count' else 'n'
        pie_data = dataset[cat_col].value_counts().reset_index()
        pie_data.columns = [cat_col, value_col]

    # Both branches leave pie_data sorted largest-first, so head() keeps
    # the top N slices. A value of 0 means "show everything".
    if top_n > 0:
        pie_data = pie_data.head(top_n)

    fig = px.pie(
        pie_data,
        names=cat_col,
        values=value_col,
        title=f"Pie Chart of {cat_col}",
        template="simple_white"
    )
    st.plotly_chart(fig, use_container_width=True)