import gradio as gr
import numpy as np
import pandas as pd
from sklearn.cluster import KMeans
from sklearn.metrics import pairwise_distances_argmin_min  # r2_score is not used in the final Gradio app logic
import matplotlib.pyplot as plt
import matplotlib.cm
import io
import os  # For path joining
from PIL import Image

# Define the paths for example data
EXAMPLE_DATA_DIR = "eg_data"
EXAMPLE_FILES = {
    "cashflow_base": os.path.join(EXAMPLE_DATA_DIR, "cashflows_seriatim_10K.xlsx"),
    "cashflow_lapse": os.path.join(EXAMPLE_DATA_DIR, "cashflows_seriatim_10K_lapse50.xlsx"),
    "cashflow_mort": os.path.join(EXAMPLE_DATA_DIR, "cashflows_seriatim_10K_mort15.xlsx"),
    "policy_data": os.path.join(EXAMPLE_DATA_DIR, "model_point_table.xlsx"),
    "pv_base": os.path.join(EXAMPLE_DATA_DIR, "pv_seriatim_10K.xlsx"),
    "pv_lapse": os.path.join(EXAMPLE_DATA_DIR, "pv_seriatim_10K_lapse50.xlsx"),
    "pv_mort": os.path.join(EXAMPLE_DATA_DIR, "pv_seriatim_10K_mort15.xlsx"),
}


class Clusters:

    def __init__(self, loc_vars):
        # Validate input before fitting KMeans
        if loc_vars.empty:
            raise ValueError("Input data for KMeans (loc_vars) is empty.")
        if loc_vars.isnull().all().all():
            raise ValueError("Input data for KMeans (loc_vars) contains all NaN values.")

        self.kmeans = KMeans(
            n_clusters=min(1000, len(loc_vars)), random_state=0, n_init=10
        ).fit(np.ascontiguousarray(loc_vars))

        # For each cluster centre, find the nearest actual policy: the representative
        closest, _ = pairwise_distances_argmin_min(
            self.kmeans.cluster_centers_, np.ascontiguousarray(loc_vars))

        rep_ids = pd.Series(data=(closest + 1))  # 0-based to 1-based policy IDs
        rep_ids.name = 'policy_id'
        rep_ids.index.name = 'cluster_id'
        self.rep_ids = rep_ids

        self.policy_count = self.agg_by_cluster(
            pd.DataFrame({'policy_count': [1] * len(loc_vars)}))['policy_count']

    def agg_by_cluster(self, df, agg=None):
        """Aggregate df by cluster, applying per-column operations (default: sum)."""
        temp = df.copy()
        temp['cluster_id'] = self.kmeans.labels_
        temp = temp.set_index('cluster_id')
        if isinstance(agg, dict):
            agg_ops = {c: agg.get(c, 'sum') for c in temp.columns}
        elif agg is not None:
            # Unexpected agg format: fall back to summing every column
            agg_ops = {col: 'sum' for col in temp.columns}
        else:  # agg is None: pandas groupby sums all numeric columns
            agg_ops = 'sum'
        return temp.groupby(temp.index).agg(agg_ops)

    def extract_reps(self, df):
        """Return the representative policy's row for each cluster."""
        temp = pd.merge(self.rep_ids, df.reset_index(), how='left', on='policy_id')
        temp.index.name = 'cluster_id'
        return temp.drop('policy_id', axis=1)

    def extract_and_scale_reps(self, df, agg=None):
        """Return representative rows, scaling 'sum' columns by cluster policy counts.

        Columns marked 'mean' are left as the raw representative values; they are
        weighted by policy count later, in compare_total.
        """
        extracted_df = self.extract_reps(df)
        if extracted_df.empty:
            return extracted_df  # No representatives to scale
        if agg and isinstance(agg, dict):
            scaled_df = extracted_df.copy()
            for c in extracted_df.columns:
                if agg.get(c, 'sum') == 'sum':  # Default to 'sum' if column not in agg
                    scaled_df[c] = extracted_df[c].mul(self.policy_count, axis=0)
                # 'mean' columns are intentionally not scaled by policy_count here
            return scaled_df
        else:
            # Default: scale all columns by policy_count (as if for 'sum')
            return extracted_df.mul(self.policy_count, axis=0)
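    # Scaling sketch (illustrative numbers, not from real data): with two clusters
    # of sizes 3 and 2 whose representatives have sum_assured 100 and 200 and
    # age_at_entry 40 and 50, extract_and_scale_reps(df, {'sum_assured': 'sum',
    # 'age_at_entry': 'mean'}) yields sum_assured = [300, 400] (scaled by count)
    # while age_at_entry stays [40, 50]; the weighted mean
    # (40*3 + 50*2) / 5 = 44 is only formed in compare_total.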
    def compare(self, df, agg=None):
        """Return per-cluster actual vs. estimate values, stacked by column.

        `source` holds the true per-cluster aggregates. `target` holds, for 'sum'
        columns, the representative value scaled by the cluster's policy count
        (an estimate of the cluster total), and for 'mean' columns, the raw
        representative value (an estimate of the cluster mean). This is a
        per-cluster comparison, not a total; it feeds the scatter plots.
        """
        source = self.agg_by_cluster(df, agg)
        target = self.extract_and_scale_reps(df, agg)
        return pd.DataFrame({'actual': source.stack(), 'estimate': target.stack()})
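    # Output sketch (values and column names are hypothetical): compare(cfs)
    # returns a frame indexed by (cluster_id, column), e.g.
    #
    #                        actual   estimate
    #   cluster_id column
    #   0          CF_0      1234.5     1201.3
    #   0          CF_1       987.6     1010.2
    #
    # plot_scatter_comparison groups on the second index level to colour
    # points by column.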
    def compare_total(self, df, agg=None):
        """Aggregate df by columns and compare actual vs. estimate totals."""
        if df.empty:
            return pd.DataFrame(columns=['actual', 'estimate', 'error'])

        # Determine the aggregation operation for each column
        op_for_actual = {}
        if isinstance(agg, dict):
            for c in df.columns:
                op_for_actual[c] = agg.get(c, 'sum')  # Default to 'sum' if not in agg
        else:
            # agg is None or not a dict: sum all numeric columns;
            # non-numeric columns are ignored by df.agg
            for c in df.columns:
                if pd.api.types.is_numeric_dtype(df[c]):
                    op_for_actual[c] = 'sum'

        actual = df.agg(op_for_actual)
        actual = actual.dropna()  # Remove non-numeric results if any

        # Calculate the estimate from the representative rows (one per cluster)
        reps_values = self.extract_reps(df)
        if reps_values.empty:  # No representatives found
            estimate = pd.Series(index=actual.index, dtype=float)
        else:
            estimate_values = {}
            for col_name in actual.index:  # Columns with a valid actual aggregation
                col_op = op_for_actual.get(col_name, 'sum')
                if col_name not in reps_values.columns:  # Should not happen if df columns match
                    estimate_values[col_name] = np.nan
                    continue
                rep_col_values = reps_values[col_name]
                if col_op == 'sum':
                    # Estimated total: sum of representative value * cluster policy count
                    estimate_values[col_name] = (rep_col_values * self.policy_count).sum()
                elif col_op == 'mean':
                    # Estimated mean: weighted average sum(rep * count) / sum(count)
                    weighted_sum = (rep_col_values * self.policy_count).sum()
                    total_weight = self.policy_count.sum()
                    estimate_values[col_name] = (
                        weighted_sum / total_weight if total_weight != 0 else np.nan)
                else:  # Should not happen given op_for_actual logic
                    estimate_values[col_name] = np.nan
            estimate = pd.Series(estimate_values, index=actual.index)  # Align with actual

        # Calculate the relative error, aligning actual and estimate first
        actual_aligned, estimate_aligned = actual.align(estimate, join='inner')
        error = pd.Series(index=actual_aligned.index, dtype=float)
        # Valid division where actual is non-zero and not NaN
        valid_mask = (actual_aligned != 0) & (~actual_aligned.isna())
        error[valid_mask] = estimate_aligned[valid_mask] / actual_aligned[valid_mask] - 1
        # Where actual is zero: error is 0 if estimate is also zero, else infinite
        actual_zero_mask = (actual_aligned == 0) & (~actual_aligned.isna())
        error[actual_zero_mask & (estimate_aligned == 0)] = 0
        error[actual_zero_mask & (estimate_aligned != 0)] = np.inf
        # Replace infinities with NaN for cleaner downstream stats (e.g. .mean())
        error = error.replace([np.inf, -np.inf], np.nan)

        return pd.DataFrame(
            {'actual': actual_aligned, 'estimate': estimate_aligned, 'error': error})
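# Usage sketch (illustrative; assumes an Excel file of per-policy cashflows
# indexed by a 1-based policy_id, as the files this app expects):
#
#   cfs = pd.read_excel("cashflows_seriatim_10K.xlsx", index_col=0)
#   clusters = Clusters(cfs)                 # calibrate on raw cashflows
#   totals = clusters.compare_total(cfs)     # columns: actual, estimate, error
#
# The error column is relative: estimate / actual - 1, so an actual total of
# 100 with an estimate of 98 gives an error of -0.02 (a 2% understatement).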
def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
    if not cfs_list or not cluster_obj or not titles:
        fig, ax = plt.subplots()
        ax.text(0.5, 0.5, "No data for cashflow comparison plot.", ha='center', va='center')
        buf = io.BytesIO()
        plt.savefig(buf, format='png')
        buf.seek(0)
        img = Image.open(buf)
        plt.close(fig)
        return img

    num_plots = len(cfs_list)
    cols = 2
    rows = (num_plots + cols - 1) // cols
    fig, axes = plt.subplots(rows, cols, figsize=(15, 5 * rows), squeeze=False)
    axes = axes.flatten()
    plot_made = False

    for i, (df_cf, title) in enumerate(zip(cfs_list, titles)):
        if i >= len(axes):
            break
        if df_cf is None or df_cf.empty:
            axes[i].text(0.5, 0.5, f"No data for {title}", ha='center', va='center')
            axes[i].set_title(title)
            continue
        comparison = cluster_obj.compare_total(df_cf)  # Default: sum for all columns
        if not comparison.empty and 'actual' in comparison and 'estimate' in comparison:
            comparison[['actual', 'estimate']].plot(ax=axes[i], grid=True, title=title)
            axes[i].set_xlabel('Time')
            axes[i].set_ylabel('Value')
            plot_made = True
        else:
            axes[i].text(0.5, 0.5, f"Could not generate comparison for {title}",
                         ha='center', va='center')
            axes[i].set_title(title)

    for j in range(i + 1, len(axes)):  # Hide unused subplots
        fig.delaxes(axes[j])

    if not plot_made:  # No plots were actually made (e.g. all data was empty)
        plt.close(fig)
        fig, ax = plt.subplots()  # New figure carrying only the message
        ax.text(0.5, 0.5, "Insufficient data for any cashflow plots.", ha='center', va='center')

    plt.tight_layout()
    buf = io.BytesIO()
    plt.savefig(buf, format='png', dpi=100)
    buf.seek(0)
    img = Image.open(buf)
    plt.close(fig)
    return img


def plot_scatter_comparison(df_compare_output, title):
    if df_compare_output is None or df_compare_output.empty:
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.text(0.5, 0.5, "No data for scatter plot.", ha='center', va='center')
        ax.set_title(title)
        buf = io.BytesIO()
        plt.savefig(buf, format='png')
        buf.seek(0)
        img = Image.open(buf)
        plt.close(fig)
        return img

    fig, ax = plt.subplots(figsize=(10, 6))
    if not isinstance(df_compare_output.index, pd.MultiIndex) or df_compare_output.index.nlevels < 2:
        # Not the MultiIndex output of Clusters.compare(); plot without grouping
        ax.scatter(df_compare_output.get('actual', []),
                   df_compare_output.get('estimate', []), s=9, alpha=0.6)
    else:
        unique_levels = df_compare_output.index.get_level_values(1).unique()
        colors = matplotlib.cm.rainbow(np.linspace(0, 1, len(unique_levels)))
        for item_level, color_val in zip(unique_levels, colors):
            subset = df_compare_output.xs(item_level, level=1)
            ax.scatter(subset['actual'], subset['estimate'], color=color_val,
                       s=9, alpha=0.6, label=str(item_level))  # Ensure label is a string
        if 1 < len(unique_levels) <= 10:
            ax.legend(title=df_compare_output.index.names[1])

    ax.set_xlabel('Actual')
    ax.set_ylabel('Estimate')
    ax.set_title(title)
    ax.grid(True)

    # Draw the y = x reference line over the shared axis range
    try:
        current_xlim = ax.get_xlim()
        current_ylim = ax.get_ylim()
        lims = [
            np.nanmin([current_xlim, current_ylim]),
            np.nanmax([current_xlim, current_ylim]),
        ]
        if lims[0] != lims[1] and not np.isnan(lims[0]) and not np.isnan(lims[1]):
            ax.plot(lims, lims, 'r-', linewidth=0.5)
            ax.set_xlim(lims)
            ax.set_ylim(lims)
    except Exception:  # Limits can be degenerate (e.g. all NaNs)
        pass

    buf = io.BytesIO()
    plt.savefig(buf, format='png', dpi=100)
    buf.seek(0)
    img = Image.open(buf)
    plt.close(fig)
    return img
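# Illustrative use (assumes the `clusters` and `cfs` objects from the sketch
# above): the scatter colours points by column and overlays a y = x line, so a
# well-calibrated clustering produces points hugging that line.
#
#   img = plot_scatter_comparison(clusters.compare(cfs), "Cashflows (Base)")
#   img.save("scatter_base.png")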
def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
                  policy_data_path, pv_base_path, pv_lapse_path, pv_mort_path):
    results = {}
    try:
        cfs = pd.read_excel(cashflow_base_path, index_col=0)
        cfs_lapse50 = pd.read_excel(cashflow_lapse_path, index_col=0)
        cfs_mort15 = pd.read_excel(cashflow_mort_path, index_col=0)

        pol_data_full = pd.read_excel(policy_data_path, index_col=0)
        required_cols = ['age_at_entry', 'policy_term', 'sum_assured', 'duration_mth']
        missing_policy_cols = [col for col in required_cols if col not in pol_data_full.columns]
        if missing_policy_cols:
            gr.Warning(f"Policy data is missing required columns: "
                       f"{', '.join(missing_policy_cols)}. Analysis may be affected.")
            pol_data = pol_data_full  # Use what's available
        else:
            pol_data = pol_data_full[required_cols]

        pvs = pd.read_excel(pv_base_path, index_col=0)
        pvs_lapse50 = pd.read_excel(pv_lapse_path, index_col=0)
        pvs_mort15 = pd.read_excel(pv_mort_path, index_col=0)

        cfs_list = [cfs, cfs_lapse50, cfs_mort15]
        scen_titles = ['Base', 'Lapse+50%', 'Mort+15%']
        # Attributes are averaged; sum_assured is totalled
        mean_attrs_agg = {'age_at_entry': 'mean', 'policy_term': 'mean',
                          'duration_mth': 'mean', 'sum_assured': 'sum'}

        # --- 1. Cashflow Calibration ---
        gr.Info("Starting Cashflow Calibration...")
        if cfs.empty:
            gr.Warning("Base cashflow data is empty for Cashflow Calibration.")
        cluster_cfs = Clusters(cfs)
        results['cf_total_base_table'] = cluster_cfs.compare_total(cfs)
        results['cf_policy_attrs_total'] = cluster_cfs.compare_total(pol_data, agg=mean_attrs_agg)
        results['cf_pv_total_base'] = cluster_cfs.compare_total(pvs)
        results['cf_pv_total_lapse'] = cluster_cfs.compare_total(pvs_lapse50)
        results['cf_pv_total_mort'] = cluster_cfs.compare_total(pvs_mort15)
        results['cf_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_cfs, scen_titles)
        results['cf_scatter_cashflows_base'] = plot_scatter_comparison(
            cluster_cfs.compare(cfs), 'CF Calib. - Cashflows (Base)')
        gr.Info("Cashflow Calibration Done.")

        # --- 2. Policy Attribute Calibration ---
        gr.Info("Starting Policy Attribute Calibration...")
        cluster_attrs = None  # Stays None if this calibration is skipped
        if pol_data.empty:
            gr.Warning("Policy data is empty. Skipping Policy Attribute Calibration.")
            loc_vars_attrs = pd.DataFrame()
        else:
            # Min-max standardization so each attribute contributes comparably
            pol_data_min = pol_data.min()
            pol_data_range = pol_data.max() - pol_data_min
            if (pol_data_range == 0).any():
                gr.Warning("Some policy attributes have no variance (all values are the same). "
                           "Standardization might be affected.")
                # Avoid division by zero: constant columns become 0 after (x - min) / 1
                pol_data_range[pol_data_range == 0] = 1
            loc_vars_attrs = (pol_data - pol_data_min) / pol_data_range
            loc_vars_attrs = loc_vars_attrs.fillna(0)  # Guard against residual NaNs

        if not loc_vars_attrs.empty:
            cluster_attrs = Clusters(loc_vars_attrs)
            results['attr_total_cf_base'] = cluster_attrs.compare_total(cfs)
            results['attr_policy_attrs_total'] = cluster_attrs.compare_total(pol_data, agg=mean_attrs_agg)
            results['attr_total_pv_base'] = cluster_attrs.compare_total(pvs)
            results['attr_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_attrs, scen_titles)
            results['attr_scatter_cashflows_base'] = plot_scatter_comparison(
                cluster_attrs.compare(cfs), 'Attr Calib. - Cashflows (Base)')
        else:
            results.update({k: pd.DataFrame() for k in
                            ['attr_total_cf_base', 'attr_policy_attrs_total', 'attr_total_pv_base']})
            results.update({k: None for k in ['attr_cashflow_plot', 'attr_scatter_cashflows_base']})
        gr.Info("Policy Attribute Calibration Done.")
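        # Worked standardization example (hypothetical numbers): an age_at_entry
        # column spanning 20..60 maps to (x - 20) / 40, so 30 -> 0.25 and
        # 60 -> 1.0; a constant column has its range forced to 1 above, so every
        # entry maps to 0 and contributes nothing to the cluster distances.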

        # --- 3. Present Value Calibration ---
        gr.Info("Starting Present Value Calibration...")
        if pvs.empty:
            gr.Warning("Base Present Value data is empty for PV Calibration.")
        cluster_pvs = Clusters(pvs)
        results['pv_total_cf_base'] = cluster_pvs.compare_total(cfs)
        results['pv_policy_attrs_total'] = cluster_pvs.compare_total(pol_data, agg=mean_attrs_agg)
        results['pv_total_pv_base'] = cluster_pvs.compare_total(pvs)
        results['pv_total_pv_lapse'] = cluster_pvs.compare_total(pvs_lapse50)
        results['pv_total_pv_mort'] = cluster_pvs.compare_total(pvs_mort15)
        results['pv_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_pvs, scen_titles)
        results['pv_scatter_pvs_base'] = plot_scatter_comparison(
            cluster_pvs.compare(pvs), 'PV Calib. - PVs (Base)')
        gr.Info("Present Value Calibration Done.")

        # --- Summary Comparison Plot Data ---
        gr.Info("Generating Summary Plot...")
        error_data = {}
        pv_col_name = 'PV_NetCF'  # Target column for the summary

        for calib_prefix, cluster_obj, calib_name_display in [
                ('CF Calib.', cluster_cfs, "CF Calib."),
                ('Attr Calib.', cluster_attrs, "Attr Calib."),
                ('PV Calib.', cluster_pvs, "PV Calib.")]:
            current_calib_errors = []
            if cluster_obj is None and calib_prefix == 'Attr Calib.':
                # Attribute calibration may have been skipped
                current_calib_errors = [np.nan, np.nan, np.nan]
            else:
                for pv_df_scenario in [pvs, pvs_lapse50, pvs_mort15]:
                    if pv_df_scenario.empty:
                        current_calib_errors.append(np.nan)
                        continue
                    comp_total_df = cluster_obj.compare_total(pv_df_scenario)
                    if pv_col_name in comp_total_df.index:
                        error_val = comp_total_df.loc[pv_col_name, 'error']
                    elif not comp_total_df.empty and 'error' in comp_total_df.columns:
                        error_val = comp_total_df['error'].mean()  # Fallback
                        if calib_prefix == 'CF Calib.' and pv_df_scenario is pvs:
                            # Only warn once per fallback type
                            gr.Warning(f"'{pv_col_name}' not found for summary plot. Using mean error "
                                       f"of all PV columns instead for {calib_name_display}.")
                    else:  # comp_total_df is empty or has no 'error' column
                        error_val = np.nan
                    current_calib_errors.append(abs(error_val))
            error_data[calib_name_display] = current_calib_errors

        summary_df = pd.DataFrame(error_data, index=['Base', 'Lapse+50%', 'Mort+15%'])

        fig_summary, ax_summary = plt.subplots(figsize=(10, 6))
        plot_title = f'Calibration Method Comparison - Abs. Error in Total {pv_col_name}'
        if summary_df.isnull().all().all():
            ax_summary.text(0.5, 0.5,
                            f"Error data for summary is N/A.\nCheck input PV files for "
                            f"'{pv_col_name}' column and valid numeric data.",
                            ha='center', va='center', transform=ax_summary.transAxes, wrap=True)
            ax_summary.set_title(plot_title)
        elif summary_df.empty:
            ax_summary.text(0.5, 0.5, "No summary data to plot.", ha='center', va='center')
            ax_summary.set_title(plot_title)
        else:
            summary_df.plot(kind='bar', ax=ax_summary, grid=True)
            ax_summary.set_ylabel(f'Mean Absolute Error (of {pv_col_name} or fallback)')
            ax_summary.set_title(plot_title)
            ax_summary.tick_params(axis='x', rotation=0)
        plt.tight_layout()
        buf_summary = io.BytesIO()
        plt.savefig(buf_summary, format='png', dpi=100)
        buf_summary.seek(0)
        results['summary_plot'] = Image.open(buf_summary)
        plt.close(fig_summary)

        gr.Info("All processing complete.")
        return results

    except FileNotFoundError as e:
        gr.Error(f"File not found: {e.filename}. Ensure example files are in "
                 f"'{EXAMPLE_DATA_DIR}' or all files are uploaded correctly.")
        return {"error": f"File not found: {e.filename}"}
    except ValueError as e:  # Catches specific errors such as empty data for KMeans
        gr.Error(f"Data validation error: {str(e)}")
        return {"error": f"Data error: {str(e)}"}
    except KeyError as e:
        gr.Error(f"A required column is missing: {e}. Please check data formats, especially "
                 f"index columns and expected data columns like 'PV_NetCF'.")
        return {"error": f"Missing column: {e}"}
    except Exception as e:
        gr.Error(f"An unexpected error occurred during processing: {str(e)}")
        import traceback
        traceback.print_exc()  # Full traceback to the console for debugging
        return {"error": f"Processing error: {str(e)}"}
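# Headless smoke-test sketch (illustrative; assumes the example files exist and
# that gr.Info/gr.Warning calls are tolerated outside a running app):
#
#   results = process_files(*(EXAMPLE_FILES[k] for k in (
#       "cashflow_base", "cashflow_lapse", "cashflow_mort",
#       "policy_data", "pv_base", "pv_lapse", "pv_mort")))
#   results["summary_plot"].save("summary.png")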
def create_interface():
    with gr.Blocks(title="Cluster Model Points Analysis") as demo:
        gr.Markdown("""
        # Cluster Model Points Analysis
        This application applies k-means cluster analysis to select representative model points
        from an insurance portfolio. Upload your Excel files or use the example data to analyze
        results based on different calibration variable choices.

        **Required Excel (.xlsx) Files:**
        - Cashflows - Base Scenario
        - Cashflows - Lapse Stress (+50%)
        - Cashflows - Mortality Stress (+15%)
        - Policy Data (must include 'age_at_entry', 'policy_term', 'sum_assured', 'duration_mth', and an index column for `policy_id`)
        - Present Values - Base Scenario (ideally with a 'PV_NetCF' column and an index column for `policy_id`)
        - Present Values - Lapse Stress (same structure as Base PV)
        - Present Values - Mortality Stress (same structure as Base PV)
        """)

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### 📂 Upload Files or Load Examples")
                load_example_btn = gr.Button("💾 Load Example Data")
                with gr.Row():
                    cashflow_base_input = gr.File(label="Cashflows - Base", file_types=[".xlsx"])
                    cashflow_lapse_input = gr.File(label="Cashflows - Lapse Stress", file_types=[".xlsx"])
                    cashflow_mort_input = gr.File(label="Cashflows - Mortality Stress", file_types=[".xlsx"])
                with gr.Row():
                    policy_data_input = gr.File(label="Policy Data", file_types=[".xlsx"])
                    pv_base_input = gr.File(label="Present Values - Base", file_types=[".xlsx"])
                    pv_lapse_input = gr.File(label="Present Values - Lapse Stress", file_types=[".xlsx"])
                with gr.Row():
                    pv_mort_input = gr.File(label="Present Values - Mortality Stress", file_types=[".xlsx"])
                    span_dummy = gr.File(visible=False)  # For layout balance if needed
                    span_dummy2 = gr.File(visible=False)

        analyze_btn = gr.Button("🚀 Analyze Dataset", variant="primary", scale=1)

        with gr.Tabs():
            with gr.TabItem("📊 Summary"):
                summary_plot_output = gr.Image(label="Calibration Methods Comparison")
Estimate)") cf_scatter_cashflows_base_out = gr.Image(label="Scatter: Per-Cluster Cashflows (Base)") with gr.Accordion("Present Value Comparisons (Totals)", open=False): with gr.Row(): cf_pv_total_base_out = gr.Dataframe(label="PVs - Base", wrap=True) cf_pv_total_lapse_out = gr.Dataframe(label="PVs - Lapse Stress", wrap=True) cf_pv_total_mort_out = gr.Dataframe(label="PVs - Mortality Stress", wrap=True) with gr.TabItem("👤 Policy Attribute Calibration"): gr.Markdown("### Results: Using Policy Attributes as Calibration Variables") with gr.Row(): attr_total_cf_base_out = gr.Dataframe(label="Overall Comparison - Base CF", wrap=True) attr_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes", wrap=True) attr_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate)") attr_scatter_cashflows_base_out = gr.Image(label="Scatter: Per-Cluster Cashflows (Base)") with gr.Accordion("Present Value Comparisons (Totals)", open=False): attr_total_pv_base_out = gr.Dataframe(label="PVs - Base Scenario", wrap=True) with gr.TabItem("💰 Present Value Calibration"): gr.Markdown("### Results: Using Present Values (Base) as Calibration Variables") with gr.Row(): pv_total_cf_base_out = gr.Dataframe(label="Overall Comparison - Base CF", wrap=True) pv_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes", wrap=True) pv_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate)") pv_scatter_pvs_base_out = gr.Image(label="Scatter: Per-Cluster PVs (Base)") with gr.Accordion("Present Value Comparisons (Totals)", open=False): with gr.Row(): pv_total_pv_base_out = gr.Dataframe(label="PVs - Base", wrap=True) pv_total_pv_lapse_out = gr.Dataframe(label="PVs - Lapse Stress", wrap=True) pv_total_pv_mort_out = gr.Dataframe(label="PVs - Mortality Stress", wrap=True) output_components = [ summary_plot_output, cf_total_base_table_out, cf_policy_attrs_total_out, cf_cashflow_plot_out, cf_scatter_cashflows_base_out, cf_pv_total_base_out, cf_pv_total_lapse_out, cf_pv_total_mort_out, attr_total_cf_base_out, attr_policy_attrs_total_out, attr_cashflow_plot_out, attr_scatter_cashflows_base_out, attr_total_pv_base_out, pv_total_cf_base_out, pv_policy_attrs_total_out, pv_cashflow_plot_out, pv_scatter_pvs_base_out, pv_total_pv_base_out, pv_total_pv_lapse_out, pv_total_pv_mort_out ] def handle_analysis_click(f1, f2, f3, f4, f5, f6, f7): all_files_present = all(f is not None for f in [f1, f2, f3, f4, f5, f6, f7]) if not all_files_present: gr.Warning("Not all files have been provided. Please upload all 7 files or load example data.") return [None] * len(output_components) # Return Nones for all output components # file objects (f1, etc.) from gr.File are TemporaryFileWrapper or string paths if loaded by example file_paths = [] for f_obj in [f1, f2, f3, f4, f5, f6, f7]: if hasattr(f_obj, 'name') and isinstance(f_obj.name, str): # Uploaded file file_paths.append(f_obj.name) elif isinstance(f_obj, str): # Path from example load file_paths.append(f_obj) else: # Should not happen if files are present gr.Error(f"Invalid file input: {f_obj}. 
        def handle_analysis_click(f1, f2, f3, f4, f5, f6, f7):
            all_files_present = all(f is not None for f in [f1, f2, f3, f4, f5, f6, f7])
            if not all_files_present:
                gr.Warning("Not all files have been provided. Please upload all 7 files or load example data.")
                return [None] * len(output_components)

            # Inputs from gr.File are temporary-file wrappers when uploaded,
            # or plain string paths when set by the example loader
            file_paths = []
            for f_obj in [f1, f2, f3, f4, f5, f6, f7]:
                if hasattr(f_obj, 'name') and isinstance(f_obj.name, str):  # Uploaded file
                    file_paths.append(f_obj.name)
                elif isinstance(f_obj, str):  # Path from example load
                    file_paths.append(f_obj)
                else:  # Should not happen if all files are present
                    gr.Error(f"Invalid file input: {f_obj}. Please re-upload or reload examples.")
                    return [None] * len(output_components)

            analysis_results = process_files(*file_paths)

            if "error" in analysis_results:
                # Error already displayed by process_files
                return [None] * len(output_components)

            # Map results to output components (same order as output_components)
            return [
                analysis_results.get('summary_plot'),
                analysis_results.get('cf_total_base_table'),
                analysis_results.get('cf_policy_attrs_total'),
                analysis_results.get('cf_cashflow_plot'),
                analysis_results.get('cf_scatter_cashflows_base'),
                analysis_results.get('cf_pv_total_base'),
                analysis_results.get('cf_pv_total_lapse'),
                analysis_results.get('cf_pv_total_mort'),
                analysis_results.get('attr_total_cf_base'),
                analysis_results.get('attr_policy_attrs_total'),
                analysis_results.get('attr_cashflow_plot'),
                analysis_results.get('attr_scatter_cashflows_base'),
                analysis_results.get('attr_total_pv_base'),
                analysis_results.get('pv_total_cf_base'),
                analysis_results.get('pv_policy_attrs_total'),
                analysis_results.get('pv_cashflow_plot'),
                analysis_results.get('pv_scatter_pvs_base'),
                analysis_results.get('pv_total_pv_base'),
                analysis_results.get('pv_total_pv_lapse'),
                analysis_results.get('pv_total_pv_mort'),
            ]

        analyze_btn.click(
            handle_analysis_click,
            inputs=[cashflow_base_input, cashflow_lapse_input, cashflow_mort_input,
                    policy_data_input, pv_base_input, pv_lapse_input, pv_mort_input],
            outputs=output_components
        )

        input_file_components = [
            cashflow_base_input, cashflow_lapse_input, cashflow_mort_input,
            policy_data_input, pv_base_input, pv_lapse_input, pv_mort_input
        ]

        def load_example_files_action():
            missing_example_files = [fp for fp in EXAMPLE_FILES.values() if not os.path.exists(fp)]
            if missing_example_files:
                gr.Error(f"Missing example data files in '{EXAMPLE_DATA_DIR}': "
                         f"{', '.join(missing_example_files)}. Please ensure they exist.")
                return [None] * len(input_file_components)
            gr.Info(f"Example data paths loaded from '{EXAMPLE_DATA_DIR}'. Click 'Analyze Dataset'.")
            return [
                EXAMPLE_FILES["cashflow_base"], EXAMPLE_FILES["cashflow_lapse"],
                EXAMPLE_FILES["cashflow_mort"], EXAMPLE_FILES["policy_data"],
                EXAMPLE_FILES["pv_base"], EXAMPLE_FILES["pv_lapse"], EXAMPLE_FILES["pv_mort"]
            ]

        load_example_btn.click(load_example_files_action, inputs=[], outputs=input_file_components)

    return demo


if __name__ == "__main__":
    if not os.path.exists(EXAMPLE_DATA_DIR):
        try:
            os.makedirs(EXAMPLE_DATA_DIR)
            print(f"Created directory '{EXAMPLE_DATA_DIR}'. Please place example Excel files there.")
            print(f"Expected files: {list(EXAMPLE_FILES.keys())}")
        except OSError as e:
            print(f"Error creating directory {EXAMPLE_DATA_DIR}: {e}. Please create it manually.")

    print("Starting Gradio application...")
    print(f"Note: Ensure your example Excel files are placed in the "
          f"'{os.getcwd()}{os.sep}{EXAMPLE_DATA_DIR}' folder.")
    print("Required policy data columns: 'age_at_entry', 'policy_term', 'sum_assured', 'duration_mth' (and an index col).")
    print("Recommended PV files column for summary: 'PV_NetCF' (and an index col).")

    demo_app = create_interface()
    demo_app.launch()