import gradio as gr
import numpy as np
import pandas as pd
from sklearn.cluster import KMeans
from sklearn.metrics import pairwise_distances_argmin_min
import matplotlib.pyplot as plt
import matplotlib.cm
import io
import os
from PIL import Image
# Define the paths for example data
EXAMPLE_DATA_DIR = "eg_data"
EXAMPLE_FILES = {
    "cashflow_base": os.path.join(EXAMPLE_DATA_DIR, "cashflows_seriatim_10K.xlsx"),
    "cashflow_lapse": os.path.join(EXAMPLE_DATA_DIR, "cashflows_seriatim_10K_lapse50.xlsx"),
    "cashflow_mort": os.path.join(EXAMPLE_DATA_DIR, "cashflows_seriatim_10K_mort15.xlsx"),
    "policy_data": os.path.join(EXAMPLE_DATA_DIR, "model_point_table.xlsx"), # Assuming this is the correct path/name for the example
    "pv_base": os.path.join(EXAMPLE_DATA_DIR, "pv_seriatim_10K.xlsx"),
    "pv_lapse": os.path.join(EXAMPLE_DATA_DIR, "pv_seriatim_10K_lapse50.xlsx"),
    "pv_mort": os.path.join(EXAMPLE_DATA_DIR, "pv_seriatim_10K_mort15.xlsx"),
}
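# Clusters implements k-means model-point selection: policies are grouped into
# 1,000 clusters, the policy nearest each cluster centre becomes the
# representative model point, and its values are scaled by the cluster's
# policy count to estimate portfolio-level totals.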
class Clusters:
    def __init__(self, loc_vars):
        # Fit k-means, then locate the actual policy nearest each cluster centre
        self.kmeans = KMeans(n_clusters=1000, random_state=0, n_init=10).fit(np.ascontiguousarray(loc_vars))
        closest, _ = pairwise_distances_argmin_min(self.kmeans.cluster_centers_, np.ascontiguousarray(loc_vars))
       
        rep_ids = pd.Series(data=(closest + 1))  # convert 0-based row positions to 1-based policy IDs
        rep_ids.name = 'policy_id'
        rep_ids.index.name = 'cluster_id'
        self.rep_ids = rep_ids
       
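        # policy_count[k] = number of policies assigned to cluster k; used to
        # scale each representative's values back up to portfolio level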
        self.policy_count = self.agg_by_cluster(pd.DataFrame({'policy_count': [1] * len(loc_vars)}))['policy_count']
    def agg_by_cluster(self, df, agg=None):
        """Aggregate columns by cluster"""
        temp = df.copy()
        temp['cluster_id'] = self.kmeans.labels_
        temp = temp.set_index('cluster_id')
        agg = {c: agg.get(c, 'sum') for c in temp.columns} if agg else 'sum'
        return temp.groupby(temp.index).agg(agg)
    def extract_reps(self, df):
        """Extract the rows of representative policies"""
        temp = pd.merge(self.rep_ids, df.reset_index(), how='left', on='policy_id')
        temp.index.name = 'cluster_id'
        return temp.drop('policy_id', axis=1)
    def extract_and_scale_reps(self, df, agg=None):
        """Extract and scale the rows of representative policies"""
        if agg:
            cols = df.columns
            mult = pd.DataFrame({c: (self.policy_count if (c not in agg or agg[c] == 'sum') else 1) for c in cols})
            # Ensure mult has same index as extract_reps(df) for proper alignment
            extracted_df = self.extract_reps(df)
            mult.index = extracted_df.index
            return extracted_df.mul(mult)
        else:
            return self.extract_reps(df).mul(self.policy_count, axis=0)
    def compare(self, df, agg=None):
        """Returns a multi-indexed Dataframe comparing actual and estimate"""
        source = self.agg_by_cluster(df, agg)
        target = self.extract_and_scale_reps(df, agg)
        return pd.DataFrame({'actual': source.stack(), 'estimate': target.stack()})
    def compare_total(self, df, agg=None):
        """Compare actual and estimated totals across all policies.

        Columns aggregate with 'sum' by default; pass agg={'col': 'mean'} to
        use a policy-count-weighted mean for that column instead.
        """
        if agg:
            op = {c: (agg[c] if c in agg else 'sum') for c in df.columns}
            actual = df.agg(op)

            # Estimate from the representative policies: 'sum' columns are the
            # cluster-size-weighted sums; 'mean' columns are policy-count-weighted
            # averages of the representative values.
            reps = self.extract_reps(df)
            total_weight = self.policy_count.sum()
            estimate_values = {}
            for c in df.columns:
                weighted = (reps[c] * self.policy_count).sum()
                if op[c] == 'mean':
                    estimate_values[c] = weighted / total_weight if total_weight else 0
                else:  # 'sum'
                    estimate_values[c] = weighted
            estimate = pd.Series(estimate_values)
        else:  # No agg specified: all columns are summed
            actual = df.sum()
            estimate = self.extract_and_scale_reps(df).sum()

        return pd.DataFrame({'actual': actual, 'estimate': estimate, 'error': estimate / actual - 1})
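# Typical use of Clusters (as in process_files below):
#   clusters = Clusters(cfs)        # calibrate on, e.g., annual cashflows
#   clusters.compare_total(pvs)     # actual vs estimated totals, with relative error
#   clusters.compare(cfs)           # per-cluster comparison, used for the scatter plots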
def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
    """Create cashflow comparison plots"""
    if not cfs_list or not cluster_obj or not titles:
        return None # Or a placeholder image
    num_plots = len(cfs_list)
    if num_plots == 0:
        return None
    # Determine subplot layout (e.g., 2x2 or adapt)
    cols = 2
    rows = (num_plots + cols - 1) // cols
   
    fig, axes = plt.subplots(rows, cols, figsize=(15, 5 * rows), squeeze=False) # Ensure axes is always 2D
    axes = axes.flatten()
   
    for i, (df, title) in enumerate(zip(cfs_list, titles)):
        if i < len(axes):
            comparison = cluster_obj.compare_total(df)
            comparison[['actual', 'estimate']].plot(ax=axes[i], grid=True, title=title)
            axes[i].set_xlabel('Time') # Assuming x-axis is time for cashflows
            axes[i].set_ylabel('Value')
   
    # Hide any unused subplots
    for j in range(i + 1, len(axes)):
        fig.delaxes(axes[j])
       
    plt.tight_layout()
    buf = io.BytesIO()
    plt.savefig(buf, format='png', dpi=100)
    buf.seek(0)
    img = Image.open(buf)
    plt.close(fig) # Ensure figure is closed
    return img
def plot_scatter_comparison(df_compare_output, title):
    """Create scatter plot comparison from compare() output"""
    if df_compare_output is None or df_compare_output.empty:
        # Create a blank plot with a message
        fig, ax = plt.subplots(figsize=(12, 8))
        ax.text(0.5, 0.5, "No data to display", ha='center', va='center', fontsize=15)
        ax.set_title(title)
        buf = io.BytesIO()
        plt.savefig(buf, format='png', dpi=100)
        buf.seek(0)
        img = Image.open(buf)
        plt.close(fig)
        return img
    fig, ax = plt.subplots(figsize=(12, 8)) # Use a single Axes object
   
    if not isinstance(df_compare_output.index, pd.MultiIndex) or df_compare_output.index.nlevels < 2:
        gr.Warning("Scatter plot data is not in the expected multi-index format. Plotting raw actual vs estimate.")
        ax.scatter(df_compare_output['actual'], df_compare_output['estimate'], s=9, alpha=0.6)
    else:
        unique_levels = df_compare_output.index.get_level_values(1).unique()
        colors = matplotlib.cm.rainbow(np.linspace(0, 1, len(unique_levels)))
       
        for item_level, color_val in zip(unique_levels, colors):
            subset = df_compare_output.xs(item_level, level=1)
            ax.scatter(subset['actual'], subset['estimate'], color=color_val, s=9, alpha=0.6, label=item_level)
        if len(unique_levels) > 1 and len(unique_levels) <= 10:  # Add a legend only when there are few items
            ax.legend(title=df_compare_output.index.names[1])
    ax.set_xlabel('Actual')
    ax.set_ylabel('Estimate')
    ax.set_title(title)
    ax.grid(True)
   
    # Draw identity line
    lims = [
        np.min([ax.get_xlim(), ax.get_ylim()]),
        np.max([ax.get_xlim(), ax.get_ylim()]),
    ]
    if lims[0] != lims[1]:  # Avoid degenerate limits when all data is zero or a single point
        ax.plot(lims, lims, 'r-', linewidth=0.5)
        ax.set_xlim(lims)
        ax.set_ylim(lims)
   
    buf = io.BytesIO()
    plt.savefig(buf, format='png', dpi=100)
    buf.seek(0)
    img = Image.open(buf)
    plt.close(fig)
    return img
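# process_files runs the full analysis: it reads the seven Excel inputs, builds
# three sets of model points by calibrating k-means on (1) annual cashflows,
# (2) min-max-scaled policy attributes, and (3) base-scenario present values,
# and summarizes the absolute error in total PV_NetCF across the three
# calibrations and three stress scenarios in a bar chart.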
def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
                  policy_data_path, pv_base_path, pv_lapse_path, pv_mort_path):
    """Main processing function - now accepts file paths"""
    try:
        # Read uploaded files using paths
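        # Each workbook is assumed to have the policy identifier in its first
        # column (read as the index via index_col=0); the index name must be
        # 'policy_id' for the merge in Clusters.extract_reps to work.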
        cfs = pd.read_excel(cashflow_base_path, index_col=0)
        cfs_lapse50 = pd.read_excel(cashflow_lapse_path, index_col=0)
        cfs_mort15 = pd.read_excel(cashflow_mort_path, index_col=0)
       
        pol_data_full = pd.read_excel(policy_data_path, index_col=0)
        # Ensure the correct columns are selected for pol_data
        required_cols = ['age_at_entry', 'policy_term', 'sum_assured', 'duration_mth']
        if all(col in pol_data_full.columns for col in required_cols):
            pol_data = pol_data_full[required_cols]
        else:
            # Fall back to the full table when required columns are missing
            gr.Warning(f"Policy data might be missing required columns. Found: {pol_data_full.columns.tolist()}")
            pol_data = pol_data_full
        pvs = pd.read_excel(pv_base_path, index_col=0)
        pvs_lapse50 = pd.read_excel(pv_lapse_path, index_col=0)
        pvs_mort15 = pd.read_excel(pv_mort_path, index_col=0)
       
        cfs_list = [cfs, cfs_lapse50, cfs_mort15]
        # pvs_list = [pvs, pvs_lapse50, pvs_mort15] # Not directly used for plotting in this structure
        scen_titles = ['Base', 'Lapse+50%', 'Mort+15%']
       
        results = {}
       
        mean_attrs = {'age_at_entry':'mean', 'policy_term':'mean', 'duration_mth':'mean', 'sum_assured': 'sum'} # sum_assured is usually summed
        # --- 1. Cashflow Calibration ---
        cluster_cfs = Clusters(cfs)
       
        results['cf_total_base_table'] = cluster_cfs.compare_total(cfs)
        # results['cf_total_lapse_table'] = cluster_cfs.compare_total(cfs_lapse50) # For full detail if needed
        # results['cf_total_mort_table'] = cluster_cfs.compare_total(cfs_mort15)
        results['cf_policy_attrs_total'] = cluster_cfs.compare_total(pol_data, agg=mean_attrs)
       
        results['cf_pv_total_base'] = cluster_cfs.compare_total(pvs)
        results['cf_pv_total_lapse'] = cluster_cfs.compare_total(pvs_lapse50)
        results['cf_pv_total_mort'] = cluster_cfs.compare_total(pvs_mort15)
       
        results['cf_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_cfs, scen_titles)
        results['cf_scatter_cashflows_base'] = plot_scatter_comparison(cluster_cfs.compare(cfs), 'Cashflow Calib. - Cashflows (Base)')
        # results['cf_scatter_policy_attrs'] = plot_scatter_comparison(cluster_cfs.compare(pol_data, agg=mean_attrs), 'Cashflow Calib. - Policy Attributes')
        # results['cf_scatter_pvs_base'] = plot_scatter_comparison(cluster_cfs.compare(pvs), 'Cashflow Calib. - PVs (Base)')
        # --- 2. Policy Attribute Calibration ---
        # Standardize policy attributes
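        # Min-max scaling maps each attribute to [0, 1] so no single attribute
        # dominates the k-means distance metric.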
        if not pol_data.empty and (pol_data.max() - pol_data.min() != 0).all():  # Avoid division by zero if a column is constant
            loc_vars_attrs = (pol_data - pol_data.min()) / (pol_data.max() - pol_data.min())
        else:
            gr.Warning("Policy data for attribute calibration is empty or has a constant column. Skipping attribute calibration plots.")
            loc_vars_attrs = pol_data
       
        if not loc_vars_attrs.empty:
            cluster_attrs = Clusters(loc_vars_attrs)
            results['attr_total_cf_base'] = cluster_attrs.compare_total(cfs)
            results['attr_policy_attrs_total'] = cluster_attrs.compare_total(pol_data, agg=mean_attrs)
            results['attr_total_pv_base'] = cluster_attrs.compare_total(pvs)
            results['attr_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_attrs, scen_titles)
            results['attr_scatter_cashflows_base'] = plot_scatter_comparison(cluster_attrs.compare(cfs), 'Policy Attr. Calib. - Cashflows (Base)')
            # results['attr_scatter_policy_attrs'] = plot_scatter_comparison(cluster_attrs.compare(pol_data, agg=mean_attrs), 'Policy Attr. Calib. - Policy Attributes')
        else: # Fill with None if skipped
            results['attr_total_cf_base'] = pd.DataFrame()
            results['attr_policy_attrs_total'] = pd.DataFrame()
            results['attr_total_pv_base'] = pd.DataFrame()
            results['attr_cashflow_plot'] = None
            results['attr_scatter_cashflows_base'] = None
        # --- 3. Present Value Calibration ---
        cluster_pvs = Clusters(pvs)
       
        results['pv_total_cf_base'] = cluster_pvs.compare_total(cfs)
        results['pv_policy_attrs_total'] = cluster_pvs.compare_total(pol_data, agg=mean_attrs)
       
        results['pv_total_pv_base'] = cluster_pvs.compare_total(pvs)
        results['pv_total_pv_lapse'] = cluster_pvs.compare_total(pvs_lapse50)
        results['pv_total_pv_mort'] = cluster_pvs.compare_total(pvs_mort15)
       
        results['pv_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_pvs, scen_titles)
        results['pv_scatter_pvs_base'] = plot_scatter_comparison(cluster_pvs.compare(pvs), 'PV Calib. - PVs (Base)')
        # results['pv_scatter_cashflows_base'] = plot_scatter_comparison(cluster_pvs.compare(cfs), 'PV Calib. - Cashflows (Base)')
        # --- Summary Comparison Plot Data ---
        # Error metric: absolute relative error in the total net present value of
        # cashflows (the 'PV_NetCF' column). If that column is absent, fall back
        # to the mean absolute 'error' across all columns from compare_total.
       
        error_data = {}
       
        # Cashflow Calibration Errors
        if 'PV_NetCF' in pvs.columns:
            err_cf_cal_pv_base = cluster_cfs.compare_total(pvs).loc['PV_NetCF', 'error']
            err_cf_cal_pv_lapse = cluster_cfs.compare_total(pvs_lapse50).loc['PV_NetCF', 'error']
            err_cf_cal_pv_mort = cluster_cfs.compare_total(pvs_mort15).loc['PV_NetCF', 'error']
            error_data['CF Calib. (PV NetCF)'] = [
                abs(err_cf_cal_pv_base), abs(err_cf_cal_pv_lapse), abs(err_cf_cal_pv_mort)
            ]
        else: # Fallback if PV_NetCF is not present
            error_data['CF Calib. (PV NetCF)'] = [
                abs(cluster_cfs.compare_total(pvs)['error'].mean()),
                abs(cluster_cfs.compare_total(pvs_lapse50)['error'].mean()),
                abs(cluster_cfs.compare_total(pvs_mort15)['error'].mean())
            ]
        # Policy Attribute Calibration Errors
        if not loc_vars_attrs.empty and 'PV_NetCF' in pvs.columns:
            err_attr_cal_pv_base = cluster_attrs.compare_total(pvs).loc['PV_NetCF', 'error']
            err_attr_cal_pv_lapse = cluster_attrs.compare_total(pvs_lapse50).loc['PV_NetCF', 'error']
            err_attr_cal_pv_mort = cluster_attrs.compare_total(pvs_mort15).loc['PV_NetCF', 'error']
            error_data['Attr Calib. (PV NetCF)'] = [
                abs(err_attr_cal_pv_base), abs(err_attr_cal_pv_lapse), abs(err_attr_cal_pv_mort)
            ]
        else:
            error_data['Attr Calib. (PV NetCF)'] = [np.nan, np.nan, np.nan]  # Placeholder if attribute calibration was skipped
        # Present Value Calibration Errors
        if 'PV_NetCF' in pvs.columns:
            err_pv_cal_pv_base = cluster_pvs.compare_total(pvs).loc['PV_NetCF', 'error']
            err_pv_cal_pv_lapse = cluster_pvs.compare_total(pvs_lapse50).loc['PV_NetCF', 'error']
            err_pv_cal_pv_mort = cluster_pvs.compare_total(pvs_mort15).loc['PV_NetCF', 'error']
            error_data['PV Calib. (PV NetCF)'] = [
                abs(err_pv_cal_pv_base), abs(err_pv_cal_pv_lapse), abs(err_pv_cal_pv_mort)
            ]
        else:
            error_data['PV Calib. (PV NetCF)'] = [
                abs(cluster_pvs.compare_total(pvs)['error'].mean()),
                abs(cluster_pvs.compare_total(pvs_lapse50)['error'].mean()),
                abs(cluster_pvs.compare_total(pvs_mort15)['error'].mean())
            ]
       
        # Create Summary Plot
        summary_df = pd.DataFrame(error_data, index=['Base', 'Lapse+50%', 'Mort+15%'])
       
        fig_summary, ax_summary = plt.subplots(figsize=(10, 6))
        summary_df.plot(kind='bar', ax=ax_summary, grid=True)
        ax_summary.set_ylabel('Absolute Relative Error (PV_NetCF)')
        ax_summary.set_title('Calibration Method Comparison - Error in Total PV Net Cashflow')
        ax_summary.tick_params(axis='x', rotation=0)
        plt.tight_layout()
       
        buf_summary = io.BytesIO()
        plt.savefig(buf_summary, format='png', dpi=100)
        buf_summary.seek(0)
        results['summary_plot'] = Image.open(buf_summary)
        plt.close(fig_summary)
       
        return results
       
    except FileNotFoundError as e:
        gr.Error(f"File not found: {e.filename}. Please ensure example files are in '{EXAMPLE_DATA_DIR}' or all files are uploaded.")
        return {"error": f"File not found: {e.filename}"}
    except KeyError as e:
        gr.Error(f"A required column is missing from one of the excel files: {e}. Please check data format.")
        return {"error": f"Missing column: {e}"}
    except Exception as e:
        gr.Error(f"Error processing files: {str(e)}")
        return {"error": f"Error processing files: {str(e)}"}
def create_interface():
    with gr.Blocks(title="Cluster Model Points Analysis") as demo:
        gr.Markdown("""
        # Cluster Model Points Analysis
       
        This application applies cluster analysis to model point selection for insurance portfolios.
        Upload your Excel files or use the example data to analyze cashflows, policy attributes, and present values using different calibration methods.
       
        **Required Files (Excel .xlsx):**
        - Cashflows - Base Scenario
        - Cashflows - Lapse Stress (+50%)
        - Cashflows - Mortality Stress (+15%)
        - Policy Data (including 'age_at_entry', 'policy_term', 'sum_assured', 'duration_mth')
        - Present Values - Base Scenario
        - Present Values - Lapse Stress
        - Present Values - Mortality Stress
        """)
       
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Upload Files or Load Examples")
               
                load_example_btn = gr.Button("Load Example Data")
               
                with gr.Row():
                    cashflow_base_input = gr.File(label="Cashflows - Base", file_types=[".xlsx"])
                    cashflow_lapse_input = gr.File(label="Cashflows - Lapse Stress", file_types=[".xlsx"])
                    cashflow_mort_input = gr.File(label="Cashflows - Mortality Stress", file_types=[".xlsx"])
                with gr.Row():
                    policy_data_input = gr.File(label="Policy Data", file_types=[".xlsx"])
                    pv_base_input = gr.File(label="Present Values - Base", file_types=[".xlsx"])
                    pv_lapse_input = gr.File(label="Present Values - Lapse Stress", file_types=[".xlsx"])
                with gr.Row():
                    pv_mort_input = gr.File(label="Present Values - Mortality Stress", file_types=[".xlsx"])
               
                analyze_btn = gr.Button("Analyze Dataset", variant="primary", size="lg")
       
        with gr.Tabs():
            with gr.TabItem(" Summary"):
                summary_plot_output = gr.Image(label="Calibration Methods Comparison (Error in Total PV Net Cashflow)")
           
            with gr.TabItem(" Cashflow Calibration"):
                gr.Markdown("### Results: Using Annual Cashflows as Calibration Variables")
                with gr.Row():
                    cf_total_base_table_out = gr.Dataframe(label="Overall Comparison - Base Scenario (Cashflows)")
                    cf_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes")
                cf_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate) Across Scenarios")
                cf_scatter_cashflows_base_out = gr.Image(label="Scatter Plot - Per-Cluster Cashflows (Base Scenario)")
                with gr.Accordion("Present Value Comparisons (Total)", open=False):
                    with gr.Row():
                        cf_pv_total_base_out = gr.Dataframe(label="PVs - Base Total")
                        cf_pv_total_lapse_out = gr.Dataframe(label="PVs - Lapse Stress Total")
                        cf_pv_total_mort_out = gr.Dataframe(label="PVs - Mortality Stress Total")
           
            with gr.TabItem(" Policy Attribute Calibration"):
                gr.Markdown("### Results: Using Policy Attributes as Calibration Variables")
                with gr.Row():
                    attr_total_cf_base_out = gr.Dataframe(label="Overall Comparison - Base Scenario (Cashflows)")
                    attr_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes")
                attr_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate) Across Scenarios")
                attr_scatter_cashflows_base_out = gr.Image(label="Scatter Plot - Per-Cluster Cashflows (Base Scenario)")
                with gr.Accordion("Present Value Comparisons (Total)", open=False):
                     attr_total_pv_base_out = gr.Dataframe(label="PVs - Base Scenario Total")
            with gr.TabItem(" Present Value Calibration"):
                gr.Markdown("### Results: Using Present Values (Base Scenario) as Calibration Variables")
                with gr.Row():
                    pv_total_cf_base_out = gr.Dataframe(label="Overall Comparison - Base Scenario (Cashflows)")
                    pv_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes")
                pv_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate) Across Scenarios")
                pv_scatter_pvs_base_out = gr.Image(label="Scatter Plot - Per-Cluster Present Values (Base Scenario)")
                with gr.Accordion("Present Value Comparisons (Total)", open=False):
                    with gr.Row():
                        pv_total_pv_base_out = gr.Dataframe(label="PVs - Base Total")
                        pv_total_pv_lapse_out = gr.Dataframe(label="PVs - Lapse Stress Total")
                        pv_total_pv_mort_out = gr.Dataframe(label="PVs - Mortality Stress Total")
        # --- Helper function to prepare outputs ---
        def get_all_output_components():
            return [
                summary_plot_output,
                # Cashflow Calib Outputs
                cf_total_base_table_out, cf_policy_attrs_total_out,
                cf_cashflow_plot_out, cf_scatter_cashflows_base_out,
                cf_pv_total_base_out, cf_pv_total_lapse_out, cf_pv_total_mort_out,
                # Attribute Calib Outputs
                attr_total_cf_base_out, attr_policy_attrs_total_out,
                attr_cashflow_plot_out, attr_scatter_cashflows_base_out, attr_total_pv_base_out,
                # PV Calib Outputs
                pv_total_cf_base_out, pv_policy_attrs_total_out,
                pv_cashflow_plot_out, pv_scatter_pvs_base_out,
                pv_total_pv_base_out, pv_total_pv_lapse_out, pv_total_pv_mort_out
            ]
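        # NOTE: handle_analysis below must return its results in exactly this
        # component order, since Gradio maps outputs to components positionally.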
       
        # --- Action for Analyze Button ---
        def handle_analysis(f1, f2, f3, f4, f5, f6, f7):
            # Ensure all files are provided (either by upload or example load)
            files = [f1, f2, f3, f4, f5, f6, f7]
            # Gradio File objects have a .name attribute for the temp path
            # If they are already strings (from example load), they are paths
           
            file_paths = []
            for i, f_obj in enumerate(files):
                if f_obj is None:
                    gr.Error(f"Missing file input for argument {i+1}. Please upload all files or load examples.")
                    # Return Nones for all output components
                    return [None] * len(get_all_output_components())
               
                # If f_obj is a Gradio FileData object (from direct upload)
                if hasattr(f_obj, 'name') and isinstance(f_obj.name, str):
                    file_paths.append(f_obj.name)
                # If f_obj is already a string path (from example load)
                elif isinstance(f_obj, str):
                    file_paths.append(f_obj)
                else:
                    gr.Error(f"Invalid file input for argument {i+1}. Type: {type(f_obj)}")
                    return [None] * len(get_all_output_components())
            results = process_files(*file_paths)
           
            if "error" in results:
                # Error already displayed by process_files or here
                return [None] * len(get_all_output_components())
           
            return [
                results.get('summary_plot'),
                # CF Calib
                results.get('cf_total_base_table'), results.get('cf_policy_attrs_total'),
                results.get('cf_cashflow_plot'), results.get('cf_scatter_cashflows_base'),
                results.get('cf_pv_total_base'), results.get('cf_pv_total_lapse'), results.get('cf_pv_total_mort'),
                # Attr Calib
                results.get('attr_total_cf_base'), results.get('attr_policy_attrs_total'),
                results.get('attr_cashflow_plot'), results.get('attr_scatter_cashflows_base'), results.get('attr_total_pv_base'),
                # PV Calib
                results.get('pv_total_cf_base'), results.get('pv_policy_attrs_total'),
                results.get('pv_cashflow_plot'), results.get('pv_scatter_pvs_base'),
                results.get('pv_total_pv_base'), results.get('pv_total_pv_lapse'), results.get('pv_total_pv_mort')
            ]
        analyze_btn.click(
            handle_analysis,
            inputs=[cashflow_base_input, cashflow_lapse_input, cashflow_mort_input,
                    policy_data_input, pv_base_input, pv_lapse_input, pv_mort_input],
            outputs=get_all_output_components()
        )
        # --- Action for Load Example Data Button ---
        def load_example_files():
            # Check if all example files exist
            missing_files = [fp for fp in EXAMPLE_FILES.values() if not os.path.exists(fp)]
            if missing_files:
                gr.Error(f"Missing example data files in '{EXAMPLE_DATA_DIR}': {', '.join(missing_files)}. Please ensure they exist.")
                return [None] * 7 # Return Nones for all file inputs
           
            gr.Info("Example data paths loaded. Click 'Analyze Dataset'.")
            return [
                EXAMPLE_FILES["cashflow_base"], EXAMPLE_FILES["cashflow_lapse"], EXAMPLE_FILES["cashflow_mort"],
                EXAMPLE_FILES["policy_data"], EXAMPLE_FILES["pv_base"], EXAMPLE_FILES["pv_lapse"],
                EXAMPLE_FILES["pv_mort"]
            ]
        load_example_btn.click(
            load_example_files,
            inputs=[],
            outputs=[cashflow_base_input, cashflow_lapse_input, cashflow_mort_input,
                     policy_data_input, pv_base_input, pv_lapse_input, pv_mort_input]
        )
           
    return demo
if __name__ == "__main__":
    # Create the example-data directory if it doesn't exist; the user must
    # supply the actual Excel files (empty placeholders would fail in pd.read_excel).
    if not os.path.exists(EXAMPLE_DATA_DIR):
        os.makedirs(EXAMPLE_DATA_DIR)
        print(f"Created directory '{EXAMPLE_DATA_DIR}'. Please place example Excel files there.")
        print(f"Expected files in '{EXAMPLE_DATA_DIR}': {list(EXAMPLE_FILES.values())}")
    demo_app = create_interface()
    demo_app.launch()