import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from io import BytesIO
from scipy.signal import savgol_filter
from scipy.ndimage import gaussian_filter1d

# Function to save and download images with specified DPI
def download_button(fig, file_name, file_format, dpi):
    """Render *fig* in memory and offer it through a Streamlit download button.

    Args:
        fig: Matplotlib figure to export.
        file_name: Base name of the downloaded file; ``_<dpi>dpi.<ext>`` is appended.
        file_format: Output format understood by ``Figure.savefig`` (e.g. "PNG",
            "SVG", "PDF"); case-insensitive, also used as the file extension.
        dpi: Resolution passed to ``savefig`` and shown in the label/file name.
    """
    import mimetypes

    extension = file_format.lower()
    buffer = BytesIO()
    fig.savefig(buffer, format=extension, dpi=dpi, bbox_inches="tight")
    buffer.seek(0)
    # "image/<ext>" is only right for some formats (SVG is image/svg+xml, JPG is
    # image/jpeg, PDF is application/pdf); fall back to it for unknown extensions.
    mime_type, _ = mimetypes.guess_type(f"figure.{extension}")
    st.download_button(
        label=f"Download as {file_format.upper()} ({dpi} DPI)",
        data=buffer,
        file_name=f"{file_name}_{dpi}dpi.{extension}",
        mime=mime_type or f"image/{extension}",
    )

# Function to apply smoothing
def apply_smoothing(data, method, **params):
    """Smooth a 1-D sequence of y-values with the named method.

    Args:
        data: 1-D array-like of numeric y-values.
        method: One of "None", "Moving Average", "Gaussian", "Savitzky-Golay",
            "Median", "Combined", "Exponential Moving Average", "LOWESS",
            "Butterworth" or "Fourier Transform". Any other value returns
            *data* unchanged.
        **params: Method-specific settings (window_size, sigma, poly_order,
            kernel_size, span, frac, order, cutoff, keep_fraction); missing
            keys fall back to the defaults used below.

    Returns:
        *data* itself for "None" or an unknown method; otherwise a float
        ndarray with the same length as *data*.
    """
    if method == "None":
        return data

    # Work in float: gaussian_filter1d returns the input dtype, so integer
    # columns would otherwise have their smoothed values truncated.
    values = np.asarray(data, dtype=float)

    if method == "Moving Average":
        return _rolling(values, params.get("window_size", 5), median=False)
    if method == "Gaussian":
        return gaussian_filter1d(values, sigma=params.get("sigma", 2))
    if method == "Savitzky-Golay":
        return _savgol(values, params.get("window_size", 5), params.get("poly_order", 2))
    if method == "Median":
        return _rolling(values, params.get("kernel_size", 5), median=True)
    if method == "Combined":
        # Median first to knock out spikes, then Savitzky-Golay to smooth.
        despiked = _rolling(values, params.get("kernel_size", 5), median=True)
        return _savgol(despiked, params.get("window_size", 21), params.get("poly_order", 3))
    if method == "Exponential Moving Average":
        return pd.Series(values).ewm(span=params.get("span", 10)).mean().to_numpy()
    if method == "LOWESS":
        return _lowess(values, params.get("frac", 0.1))
    if method == "Butterworth":
        return _butterworth(values, params.get("order", 3), params.get("cutoff", 0.05))
    if method == "Fourier Transform":
        return _fourier_lowpass(values, params.get("keep_fraction", 0.1))
    return data


def _rolling(values, window, median):
    """Centred rolling mean (or median); the ends use a partial window instead of NaN."""
    # min_periods=1 keeps the first/last window//2 points instead of dropping them.
    rolled = pd.Series(values).rolling(window=int(window), center=True, min_periods=1)
    return (rolled.median() if median else rolled.mean()).to_numpy()


def _savgol(values, window, poly_order):
    """Savitzky-Golay filter whose window is shrunk to fit series shorter than it."""
    n = values.size
    # savgol_filter (mode="interp") needs window <= n and window > poly_order.
    window = min(int(window), n if n % 2 else n - 1)
    if window <= poly_order:
        return values.copy()
    return savgol_filter(values, window_length=window, polyorder=poly_order)


def _lowess(values, frac):
    """Single-pass LOWESS: tricube-weighted local linear fit over the sample index.

    *frac* is the fraction of points in each local window (at least 3 points).
    No robustness re-weighting iterations are performed.
    """
    n = values.size
    if n < 3:
        return values.copy()
    span = min(n, max(3, int(np.ceil(frac * n))))
    half = span // 2
    positions = np.arange(n, dtype=float)
    smoothed = np.empty(n)
    for i in range(n):
        # Window of `span` neighbours, slid inward at the edges.
        start = min(max(i - half, 0), n - span)
        xs = positions[start:start + span]
        ys = values[start:start + span]
        dist = np.abs(xs - i)
        weights = (1.0 - (dist / dist.max()) ** 3) ** 3
        total = weights.sum()
        x_mean = (weights @ xs) / total
        y_mean = (weights @ ys) / total
        dx = xs - x_mean
        spread = weights @ (dx * dx)
        slope = (weights @ (dx * (ys - y_mean))) / spread if spread > 0 else 0.0
        smoothed[i] = y_mean + slope * (i - x_mean)
    return smoothed


def _butterworth(values, order, cutoff):
    """Zero-phase low-pass Butterworth filter; *cutoff* is relative to Nyquist (0 < cutoff < 1)."""
    from scipy.signal import butter, sosfiltfilt

    n = values.size
    if n < 2:
        return values.copy()
    # Second-order sections stay numerically stable at high orders / low cutoffs.
    sos = butter(int(order), cutoff, btype="low", output="sos")
    # sosfiltfilt requires the series to be longer than its edge padding.
    padlen = min(3 * (2 * sos.shape[0] + 1), n - 1)
    return sosfiltfilt(sos, values, padlen=padlen)


def _fourier_lowpass(values, keep_fraction):
    """Zero all but the lowest *keep_fraction* of the real-FFT coefficients."""
    n = values.size
    if n == 0:
        return values.copy()
    spectrum = np.fft.rfft(values)
    keep = max(1, int(np.ceil(keep_fraction * spectrum.size)))
    spectrum[keep:] = 0
    return np.fft.irfft(spectrum, n=n)

# Unit conversions offered by both the X and Y "Scaling Option" selectors.
SCALE_OPTIONS = [
    "None", "Hz → MHz", "Hz → GHz", "mm → cm", "mm → m",
    "cm → mm", "m → mm", "Frequency → Wavelength", "Wavelength → Frequency",
    "Linear → dB", "dB → Linear",
]

# Speed of light (m/s) used for the frequency <-> wavelength conversions.
SPEED_OF_LIGHT = 3e8

# Upper bound on explicitly generated ticks; a tiny step over a wide range
# would otherwise allocate and draw an enormous number of ticks.
MAX_TICKS = 1000


def clean_column(column):
    """Coerce a text column to floats.

    Cells that already parse as numbers (including scientific notation such as
    "1.5E+09") are kept as-is; other cells have every character except digits,
    "." and "-" stripped before conversion. Anything still unparsable is NaN.
    """
    # Stripping first would mangle exponents ("1.5E+09" -> "1.509"), so only
    # fall back to it for cells that are not already valid numbers.
    direct = pd.to_numeric(column, errors="coerce")
    stripped = pd.to_numeric(column.str.replace(r"[^\d.-]", "", regex=True), errors="coerce")
    return direct.fillna(stripped)


def apply_scaling(values, scale_option, unit):
    """Convert *values* per *scale_option*; return ``(new_values, new_unit)``.

    "None" (or any unrecognised option) returns the inputs unchanged. The
    input Series is never modified in place.
    """
    conversions = {
        "Hz → MHz": (lambda v: v / 1e6, "MHz"),
        "Hz → GHz": (lambda v: v / 1e9, "GHz"),
        "mm → cm": (lambda v: v / 10, "cm"),
        "mm → m": (lambda v: v / 1000, "m"),
        "cm → mm": (lambda v: v * 10, "mm"),
        "m → mm": (lambda v: v * 1000, "mm"),
        "Frequency → Wavelength": (lambda v: SPEED_OF_LIGHT / v, "Wavelength (m)"),
        "Wavelength → Frequency": (lambda v: SPEED_OF_LIGHT / v, "Frequency (Hz)"),
        "Linear → dB": (lambda v: 10 * np.log10(v), "dB"),
        "dB → Linear": (lambda v: 10 ** (v / 10), "Linear"),
    }
    if scale_option not in conversions:
        return values, unit
    convert, new_unit = conversions[scale_option]
    return convert(values), new_unit


def tick_positions(lower, upper, step):
    """Return ticks from *lower* to *upper* every *step*, or None when unusable.

    None is returned when either limit is unset, *step* is not positive, the
    range is inverted, or more than MAX_TICKS ticks would be produced.
    """
    if lower is None or upper is None or not step or step <= 0 or upper < lower:
        return None
    if (upper - lower) / step > MAX_TICKS:
        return None
    return np.arange(lower, upper + step, step)


def parse_marker_values(text):
    """Parse comma-separated X positions, skipping blank entries (e.g. a trailing comma).

    Raises:
        ValueError: if a non-blank entry is not a number.
    """
    positions = []
    for part in text.split(","):
        token = part.strip()
        if not token:
            continue
        try:
            positions.append(float(token))
        except ValueError:
            raise ValueError(f"Invalid vertical marker value: {token!r}") from None
    return positions


# Title of the app
st.set_page_config(layout="wide")  # Use wide layout
st.title("Advanced CSV Data Visualization App")

# File uploader
uploaded_file = st.file_uploader("Upload your CSV file", type="csv")

if uploaded_file is not None:
    try:
        # Read the CSV file
        df = pd.read_csv(uploaded_file)

        # Coerce text columns (values with units, a units row under the header, ...) to numbers
        df = df.apply(lambda col: clean_column(col) if col.dtype == "object" else col)

        # Sidebar with settings
        with st.sidebar:
            st.subheader("Graph Settings")
            col1, col2 = st.columns(2)

            with col1:
                # X-Axis Settings
                x_column = st.selectbox("X-axis Column:", options=df.columns)
                x_label = st.text_input("X-axis Label:", value="Frequency")
                x_unit = st.text_input("X-axis Unit (e.g., Hz, MHz):", value="GHz")
                scale_option = st.selectbox("X Scaling Option:", SCALE_OPTIONS, index=2)
                x_min = st.number_input("Lower X-axis Limit:", value=None, format="%f")
                x_max = st.number_input("Upper X-axis Limit:", value=None, format="%f")
                x_step = st.number_input("X-axis Step Size:", value=5.0, format="%f")
                x_font_size = st.slider("X-axis Font Size:", 8, 20, 14)
                x_tick_position = st.selectbox("X-axis Tick Position:", ["out", "in", "inout"], index=2)

            with col2:
                # Y-Axis Settings
                y_columns = st.multiselect("Y-axis Column(s):", options=df.columns)
                y_label = st.text_input("Y-axis Label:", value="S21")
                y_unit = st.text_input("Y-axis Unit (e.g., dB, Linear):", value="dB")
                y_scale_option = st.selectbox("Y Scaling Option:", SCALE_OPTIONS, index=0)
                y_min = st.number_input("Lower Y-axis Limit:", value=None, format="%f")
                y_max = st.number_input("Upper Y-axis Limit:", value=None, format="%f")
                y_step = st.number_input("Y-axis Step Size:", value=10.0, format="%f")
                y_font_size = st.slider("Y-axis Font Size:", 8, 20, 14)
                y_tick_position = st.selectbox("Y-axis Tick Position:", ["out", "in", "inout"], index=2)

            # Title and Font Style Settings
            st.subheader("Title and Font Settings")
            title = st.text_input("Graph Title:", value="Advanced Graph")
            font_style = st.selectbox("Font Style:", ["Normal", "Italic", "Bold"], index=0).lower()
            font_theme = st.selectbox("Font Theme:", ["Times New Roman", "Arial", "Courier New", "Helvetica", "Verdana"], index=0)
            title_font_size = st.slider("Title Font Size:", 10, 30, 16)

            # Grid Settings
            st.subheader("Grid Settings")
            show_grid = st.checkbox("Show Grid", value=True)
            show_minor_grid = st.checkbox("Show Sub-grid (Minor Grid)", value=False)
            grid_direction = st.selectbox("Grid Direction:", ["x", "y", "both"], index=2)
            grid_line_style = st.selectbox("Grid Line Style:", ["-", "--", "-.", ":", "None"], index=0)
            grid_color = st.color_picker("Grid Line Color:", "#DDDDDD")
            grid_line_width = st.slider("Grid Line Width:", 0.5, 2.5, 1.0)

            # Enhanced Smoothing Settings
            st.subheader("Advanced Smoothing Settings")
            smoothing_method = st.selectbox(
                "Smoothing Method:",
                ["None", "Moving Average", "Gaussian", "Savitzky-Golay", "Median", "Combined", "Exponential Moving Average", "LOWESS", "Butterworth", "Fourier Transform"]
            )

            # Smoothing parameters based on selected method
            smoothing_params = {}
            if smoothing_method == "Moving Average":
                smoothing_params['window_size'] = st.slider(
                    "Window Size:",
                    3, 101, 5, step=2
                )
            elif smoothing_method == "Gaussian":
                smoothing_params['sigma'] = st.slider(
                    "Sigma (Blur Amount):",
                    0.1, 10.0, 2.0, step=0.1
                )
            elif smoothing_method == "Savitzky-Golay":
                smoothing_params['window_size'] = st.slider(
                    "Window Size:",
                    5, 101, 21, step=2
                )
                smoothing_params['poly_order'] = st.slider(
                    "Polynomial Order:",
                    1, 5, 3
                )
            elif smoothing_method == "Median":
                smoothing_params['kernel_size'] = st.slider(
                    "Kernel Size:",
                    3, 51, 5, step=2
                )
            elif smoothing_method == "Combined":
                smoothing_params['kernel_size'] = st.slider(
                    "Median Kernel Size:",
                    3, 51, 5, step=2
                )
                smoothing_params['window_size'] = st.slider(
                    "Savitzky-Golay Window:",
                    5, 101, 21, step=2
                )
                smoothing_params['poly_order'] = st.slider(
                    "Polynomial Order:",
                    1, 5, 3
                )
            elif smoothing_method == "Exponential Moving Average":
                smoothing_params['span'] = st.slider(
                    "Span (Smoothing Factor):",
                    1, 50, 10
                )
            elif smoothing_method == "LOWESS":
                smoothing_params['frac'] = st.slider(
                    "Fraction (Smoothing Proportion):",
                    0.01, 0.5, 0.1, step=0.01
                )
            elif smoothing_method == "Butterworth":
                smoothing_params['order'] = st.slider(
                    "Filter Order:",
                    1, 10, 3
                )
                smoothing_params['cutoff'] = st.slider(
                    "Cutoff Frequency:",
                    0.01, 0.5, 0.05, step=0.01
                )
            elif smoothing_method == "Fourier Transform":
                smoothing_params['keep_fraction'] = st.slider(
                    "Keep Fraction of Frequencies:",
                    0.01, 1.0, 0.1, step=0.01
                )

            # Vertical Marker Settings
            st.subheader("Vertical Marker Settings")
            marker_x_values = st.text_input("Enter X-axis Values for Markers (comma-separated):", value="")

            # DPI Setting
            dpi = st.selectbox("Select DPI for Download:", [100, 200, 300, 600], index=2)

            # Legend Customization
            st.subheader("Legend Customization")
            legend_font_size = st.slider("Font Size:", 8, 20, 10)
            legend_font_weight = st.selectbox("Font Weight:", ["Normal", "Bold"], index=0)
            legend_bg_color = st.color_picker("Background Color:", "#FFFFFF")
            legend_border_color = st.color_picker("Border Color:", "#000000")
            legend_border_width = st.slider("Border Width:", 0.5, 2.0, 1.0)
            legend_transparency = st.slider("Transparency (0 = fully opaque, 1 = fully transparent):", 0.0, 1.0, 0.5)
            legend_title = st.text_input("Legend Title:", value="")
            legend_location = st.selectbox("Legend Location:", 
                                           ["upper right", "upper center", "upper left", "center right", 
                                            "center", "center left", "lower right", 
                                            "lower center", "lower left"], index=1)
            legend_columns = st.selectbox("Legend Columns:", [1, 2, 3], index=1)

            # Line Style Settings
            st.subheader("Line Style Settings")
            style_settings = {}
            line_styles = {
                "Solid": "-", "Dashed": "--", "Dash-Dot": "-.", "Dotted": ":"
            }
            marker_styles = {
                "None": "", "Circle": "o", "Square": "s", "Star": "*", "Diamond": "D", "Triangle": "^", 
                "Pentagon": "p", "Hexagon": "H", "Plus": "+", "X": "x"
            }
            for y_column in y_columns:
                with st.expander(f"Line Style for '{y_column}'", expanded=False):
                    color = st.color_picker(f"Color for '{y_column}':", "#1f77b4", key=f"color_{y_column}")
                    line_style = st.selectbox("Line Style:", list(line_styles.keys()), key=f"ls_{y_column}")
                    marker_style = st.selectbox("Marker Style:", list(marker_styles.keys()), key=f"ms_{y_column}")
                    line_width = st.slider("Line Width:", 0.5, 5.0, 2.0, key=f"lw_{y_column}")
                    style_settings[y_column] = {
                        "color": color,
                        "line_style": line_styles[line_style],
                        "marker_style": marker_styles[marker_style],
                        "line_width": line_width,
                    }

        # Graph Output Section
        st.subheader("Graph Output")
        if st.button("Generate Graph"):
            if not y_columns:
                st.warning("Select at least one Y-axis column to plot.")
            else:
                fig, ax = plt.subplots()
                try:
                    # Scale copies rather than mutating df, so an X column that is
                    # also selected as a Y column is not converted twice.
                    x_values, x_axis_unit = apply_scaling(df[x_column], scale_option, x_unit)
                    y_axis_unit = y_unit

                    # Plot data with smoothing
                    for col in y_columns:
                        y_scaled, y_axis_unit = apply_scaling(df[col], y_scale_option, y_unit)
                        y_values = apply_smoothing(y_scaled.to_numpy(), smoothing_method, **smoothing_params)
                        style = style_settings[col]
                        ax.plot(x_values, y_values,
                                label=col,
                                linestyle=style["line_style"],
                                marker=style["marker_style"],
                                color=style["color"],
                                linewidth=style["line_width"])

                    # Explicit ticks only when both limits are set (0 is a valid limit)
                    x_ticks = tick_positions(x_min, x_max, x_step)
                    if x_ticks is not None:
                        ax.set_xticks(x_ticks)
                    y_ticks = tick_positions(y_min, y_max, y_step)
                    if y_ticks is not None:
                        ax.set_yticks(y_ticks)

                    # None leaves that side of the axis on autoscale
                    ax.set_xlim(x_min, x_max)
                    ax.set_ylim(y_min, y_max)

                    # Labels, title, and ticks. "Bold" is a font weight, not a
                    # font style: matplotlib's fontstyle only accepts
                    # normal/italic/oblique, so map the choice onto both.
                    ax.set_xlabel(f"{x_label} ({x_axis_unit})", fontsize=x_font_size)
                    ax.set_ylabel(f"{y_label} ({y_axis_unit})", fontsize=y_font_size)
                    ax.set_title(title, fontsize=title_font_size,
                                 fontstyle="italic" if font_style == "italic" else "normal",
                                 fontweight="bold" if font_style == "bold" else "normal",
                                 family=font_theme)
                    ax.tick_params(axis='x', direction=x_tick_position)
                    ax.tick_params(axis='y', direction=y_tick_position)

                    if show_grid:
                        ax.grid(True, linestyle=grid_line_style, linewidth=grid_line_width,
                                color=grid_color, axis=grid_direction)
                    if show_minor_grid:
                        ax.minorticks_on()
                        ax.grid(True, which='minor', linestyle=grid_line_style, linewidth=grid_line_width,
                                color=grid_color, axis=grid_direction)

                    # Vertical markers (blank entries such as a trailing comma are ignored)
                    for val in parse_marker_values(marker_x_values):
                        ax.axvline(x=val, color="r", linestyle="--", linewidth=1.5)

                    # Legend: framealpha is opacity, so invert the "transparency" slider;
                    # the border width applies to the frame edge, not the inner padding.
                    legend = ax.legend(title=legend_title,
                                       prop={"size": legend_font_size, "weight": legend_font_weight.lower()},
                                       title_fontsize=legend_font_size,
                                       loc=legend_location, frameon=True,
                                       facecolor=legend_bg_color, edgecolor=legend_border_color,
                                       framealpha=1.0 - legend_transparency, ncol=legend_columns)
                    legend.get_frame().set_linewidth(legend_border_width)

                    # Show graph
                    st.pyplot(fig)

                    # Download button
                    download_button(fig, "graph", "PNG", dpi)
                finally:
                    # pyplot keeps every figure from plt.subplots() alive until it is
                    # closed; release it so figures don't pile up across reruns.
                    plt.close(fig)

    except Exception as e:
        st.error(f"Error: {str(e)}")