|
import gradio as gr |
|
import numpy as np |
|
import pandas as pd |
|
from sklearn.cluster import KMeans |
|
from sklearn.metrics import pairwise_distances_argmin_min, r2_score |
|
import matplotlib.pyplot as plt |
|
import matplotlib.cm |
|
import io |
|
import os |
|
from PIL import Image |
|
|
|
|
|
# Folder (relative to the working directory) holding the bundled example workbooks.
EXAMPLE_DATA_DIR = "eg_data"

# Logical input name -> path of the matching example workbook.
# Insertion order follows the order of the file inputs in the UI.
EXAMPLE_FILES = {
    key: os.path.join(EXAMPLE_DATA_DIR, filename)
    for key, filename in (
        ("cashflow_base", "cashflows_seriatim_10K.xlsx"),
        ("cashflow_lapse", "cashflows_seriatim_10K_lapse50.xlsx"),
        ("cashflow_mort", "cashflows_seriatim_10K_mort15.xlsx"),
        ("policy_data", "model_point_table.xlsx"),
        ("pv_base", "pv_seriatim_10K.xlsx"),
        ("pv_lapse", "pv_seriatim_10K_lapse50.xlsx"),
        ("pv_mort", "pv_seriatim_10K_mort15.xlsx"),
    )
}
|
|
|
class Clusters:
    """Group policies with K-means and estimate totals from one representative per cluster.

    Attributes:
        kmeans: the fitted ``KMeans`` model; ``kmeans.labels_`` holds each
            row's cluster, in the row order of ``loc_vars``.
        rep_ids: Series (index name ``cluster_id``, series name ``policy_id``)
            holding the id of the policy nearest to each cluster centre.
        policy_count: Series indexed by cluster id holding the number of rows
            assigned to each cluster.

    DataFrames passed to the methods must list policies in the same row order
    as ``loc_vars`` (labels are assigned positionally), and ``extract_reps``
    additionally needs their index to be named ``policy_id``.
    """

    def __init__(self, loc_vars, n_clusters=1000, random_state=0, n_init=10):
        """Fit K-means on *loc_vars* and select one representative policy per cluster.

        Args:
            loc_vars: 2-D array-like of calibration variables, one row per policy.
            n_clusters: requested number of clusters (model points). It is
                capped at the number of rows because KMeans rejects
                ``n_clusters > n_samples``.
            random_state: seed forwarded to KMeans for reproducible results.
            n_init: number of KMeans initialisations, forwarded to KMeans.
        """
        data = np.ascontiguousarray(loc_vars)
        n_clusters = min(n_clusters, len(data))
        self.kmeans = kmeans = KMeans(
            n_clusters=n_clusters, random_state=random_state, n_init=n_init
        ).fit(data)
        # 0-based row position of the policy closest to each cluster centre.
        closest, _ = pairwise_distances_argmin_min(kmeans.cluster_centers_, data)

        # NOTE(review): row positions become ids by adding 1, i.e. this assumes
        # policy_id == row position + 1 in the frames given to extract_reps.
        rep_ids = pd.Series(data=(closest + 1))
        rep_ids.name = 'policy_id'
        rep_ids.index.name = 'cluster_id'
        self.rep_ids = rep_ids

        # Cluster sizes: sum a column of ones per cluster.
        self.policy_count = self.agg_by_cluster(
            pd.DataFrame({'policy_count': [1] * len(data)})
        )['policy_count']

    def agg_by_cluster(self, df, agg=None):
        """Aggregate the rows of *df* by cluster.

        Args:
            df: DataFrame with one row per policy, in the row order used for fitting.
            agg: optional {column: aggregation} mapping; columns not listed are
                summed. When omitted, every column is summed.

        Returns:
            DataFrame indexed by ``cluster_id`` with one row per cluster.
        """
        temp = df.copy()
        temp['cluster_id'] = self.kmeans.labels_
        temp = temp.set_index('cluster_id')
        if agg:
            agg = {c: (agg[c] if c in agg else 'sum') for c in temp.columns}
        else:
            agg = 'sum'
        return temp.groupby(temp.index).agg(agg)

    def extract_reps(self, df):
        """Return the rows of *df* belonging to the representative policies.

        *df* must have its index named ``policy_id``. The result is indexed by
        ``cluster_id`` in the order of ``rep_ids``; the ``policy_id`` column is
        dropped. Representatives missing from *df* yield NaN rows (left merge).
        """
        temp = pd.merge(self.rep_ids, df.reset_index(), how='left', on='policy_id')
        temp.index.name = 'cluster_id'
        return temp.drop('policy_id', axis=1)

    def extract_and_scale_reps(self, df, agg=None):
        """Return representative rows scaled up to their cluster totals.

        Without *agg*, every column is multiplied by the cluster's policy count.
        With *agg*, only columns aggregated by 'sum' (or not listed) are
        multiplied; the others are returned unscaled.
        """
        extracted = self.extract_reps(df)
        if not agg:
            return extracted.mul(self.policy_count, axis=0)

        # Explicit index: without it, a frame whose values are all the scalar 1
        # (no 'sum' column requested) cannot be constructed.
        mult = pd.DataFrame(
            {c: (self.policy_count if (c not in agg or agg[c] == 'sum') else 1) for c in df.columns},
            index=self.policy_count.index,
        )
        mult.index = extracted.index
        return extracted.mul(mult)

    def compare(self, df, agg=None):
        """Return per-cluster actual vs. estimate as a long DataFrame.

        The index is (cluster_id, column of *df*); columns are 'actual'
        (aggregate over all policies of the cluster) and 'estimate' (the scaled
        representative).
        """
        source = self.agg_by_cluster(df, agg)
        target = self.extract_and_scale_reps(df, agg)
        return pd.DataFrame({'actual': source.stack(), 'estimate': target.stack()})

    def compare_total(self, df, agg=None):
        """Compare portfolio totals of *df* with the cluster-based estimate.

        Args:
            df: DataFrame with one row per policy (index named ``policy_id``).
            agg: optional {column: aggregation}; unlisted columns are summed.
                For 'mean' columns the estimate is the representatives' values
                weighted by cluster size.

        Returns:
            DataFrame indexed by column of *df* with 'actual', 'estimate' and
            'error' (= estimate / actual - 1).
        """
        if not agg:
            actual = df.sum()
            estimate = self.extract_and_scale_reps(df).sum()
        else:
            op = {c: (agg[c] if c in agg else 'sum') for c in df.columns}
            actual = df.agg(op)
            # 'sum' columns come back scaled by cluster size; other columns unscaled.
            reps = self.extract_and_scale_reps(df, agg=op)
            total_weight = self.policy_count.sum()

            estimate_values = {}
            for col, method in op.items():
                if method == 'mean':
                    weighted_sum = (reps[col] * self.policy_count).sum()
                    estimate_values[col] = weighted_sum / total_weight if total_weight else 0
                else:
                    estimate_values[col] = reps[col].agg(method)
            estimate = pd.Series(estimate_values)

        return pd.DataFrame({'actual': actual, 'estimate': estimate, 'error': estimate / actual - 1})
|
|
|
|
|
def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
    """Plot total actual vs. estimated cashflows, one subplot per scenario.

    Args:
        cfs_list: cashflow DataFrames, one per scenario; each is totalled via
            ``cluster_obj.compare_total``.
        cluster_obj: a fitted ``Clusters`` instance.
        titles: subplot titles, paired with ``cfs_list`` by position (extra
            frames or titles are ignored).

    Returns:
        A PIL image of the two-column grid of line plots, or None when any of
        the inputs is empty.
    """
    if not cfs_list or not cluster_obj or not titles:
        return None

    n_cols = 2
    n_rows = (len(cfs_list) + n_cols - 1) // n_cols
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows), squeeze=False)
    try:
        axes = axes.flatten()
        n_plotted = 0
        for ax, df, title in zip(axes, cfs_list, titles):
            comparison = cluster_obj.compare_total(df)
            comparison[['actual', 'estimate']].plot(ax=ax, grid=True, title=title)
            ax.set_xlabel('Time')
            ax.set_ylabel('Value')
            n_plotted += 1

        # Remove grid cells that received no plot (odd count or fewer titles).
        for ax in axes[n_plotted:]:
            fig.delaxes(ax)

        # Work on this figure explicitly rather than pyplot's "current figure",
        # which may belong to another request in a concurrently-serving app.
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=100)
        buf.seek(0)
        return Image.open(buf)
    finally:
        # Always release the figure, even if plotting failed.
        plt.close(fig)
|
|
|
def plot_scatter_comparison(df_compare_output, title):
    """Scatter-plot per-cluster actual vs. estimated values.

    Args:
        df_compare_output: output of ``Clusters.compare`` - a DataFrame with
            'actual' and 'estimate' columns. With a two-level index, points are
            coloured by the second level; otherwise all points share one
            colour and a ``gr.Warning`` is shown. ``None`` or an empty frame
            produces a "No data to display" placeholder.
        title: axes title.

    Returns:
        PIL image of the rendered figure.
    """
    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        if df_compare_output is None or df_compare_output.empty:
            ax.text(0.5, 0.5, "No data to display", ha='center', va='center', fontsize=15)
            ax.set_title(title)
        else:
            index = df_compare_output.index
            if not isinstance(index, pd.MultiIndex) or index.nlevels < 2:
                gr.Warning("Scatter plot data is not in the expected multi-index format. Plotting raw actual vs estimate.")
                ax.scatter(df_compare_output['actual'], df_compare_output['estimate'], s=9, alpha=0.6)
            else:
                unique_levels = index.get_level_values(1).unique()
                colors = matplotlib.cm.rainbow(np.linspace(0, 1, len(unique_levels)))
                for item_level, color_val in zip(unique_levels, colors):
                    subset = df_compare_output.xs(item_level, level=1)
                    ax.scatter(subset['actual'], subset['estimate'], color=color_val, s=9, alpha=0.6, label=item_level)
                # A legend is only readable for a handful of series.
                if 1 < len(unique_levels) <= 10:
                    ax.legend(title=index.names[1])

            ax.set_xlabel('Actual')
            ax.set_ylabel('Estimate')
            ax.set_title(title)
            ax.grid(True)

            # y = x reference line over a square range covering both axes.
            lims = [
                np.min([ax.get_xlim(), ax.get_ylim()]),
                np.max([ax.get_xlim(), ax.get_ylim()]),
            ]
            if lims[0] != lims[1]:
                ax.plot(lims, lims, 'r-', linewidth=0.5)
                ax.set_xlim(lims)
                ax.set_ylim(lims)

        # Save this figure explicitly (not pyplot's global "current figure").
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=100)
        buf.seek(0)
        return Image.open(buf)
    finally:
        # Always release the figure, even if plotting failed.
        plt.close(fig)
|
|
|
|
|
def _min_max_scale(df):
    """Rescale every column of *df* to the [0, 1] range.

    Constant columns (max == min) map to 0 instead of the NaN a 0/0 division
    would give, so the remaining columns are still normalised for clustering.
    """
    col_min = df.min()
    col_range = (df.max() - col_min).replace(0, 1)
    return (df - col_min) / col_range


def _abs_scenario_errors(comparisons, use_net_cf):
    """Summarise each scenario's ``compare_total`` table as one absolute error.

    Args:
        comparisons: ``Clusters.compare_total`` outputs, one per scenario.
        use_net_cf: when True, take the absolute error of the 'PV_NetCF' row;
            otherwise the mean absolute error over all rows of the table.

    Returns:
        list of floats in the order of *comparisons*.
    """
    if use_net_cf:
        return [abs(comp.loc['PV_NetCF', 'error']) for comp in comparisons]
    return [comp['error'].abs().mean() for comp in comparisons]


def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
                  policy_data_path, pv_base_path, pv_lapse_path, pv_mort_path):
    """Run the cashflow, policy-attribute and PV calibrations on the given files.

    Every path points to an Excel workbook whose first column is read as the
    index. Three ``Clusters`` models are fitted (on base cashflows, on
    min-max-scaled policy attributes, and on base PVs) and each is compared
    with the full seriatim data.

    Returns:
        dict mapping result names to DataFrames / PIL images, or
        ``{"error": message}`` when reading or processing fails. Failure
        messages are also shown to the user via ``gr.Warning``.
    """
    try:
        cfs = pd.read_excel(cashflow_base_path, index_col=0)
        cfs_lapse50 = pd.read_excel(cashflow_lapse_path, index_col=0)
        cfs_mort15 = pd.read_excel(cashflow_mort_path, index_col=0)

        pol_data_full = pd.read_excel(policy_data_path, index_col=0)
        required_cols = ['age_at_entry', 'policy_term', 'sum_assured', 'duration_mth']
        if all(col in pol_data_full.columns for col in required_cols):
            pol_data = pol_data_full[required_cols]
        else:
            gr.Warning(f"Policy data might be missing required columns. Found: {pol_data_full.columns.tolist()}")
            pol_data = pol_data_full

        pvs = pd.read_excel(pv_base_path, index_col=0)
        pvs_lapse50 = pd.read_excel(pv_lapse_path, index_col=0)
        pvs_mort15 = pd.read_excel(pv_mort_path, index_col=0)

        cfs_list = [cfs, cfs_lapse50, cfs_mort15]
        scen_titles = ['Base', 'Lapse+50%', 'Mort+15%']
        mean_attrs = {'age_at_entry': 'mean', 'policy_term': 'mean', 'duration_mth': 'mean', 'sum_assured': 'sum'}
        # Summary metric: PV_NetCF error when available, else mean |error| over all PV items.
        use_net_cf = 'PV_NetCF' in pvs.columns

        results = {}
        # Insertion order = bar order in the summary chart.
        error_data = {}

        # --- Calibration 1: annual cashflows ---
        cluster_cfs = Clusters(cfs)
        results['cf_total_base_table'] = cluster_cfs.compare_total(cfs)
        results['cf_policy_attrs_total'] = cluster_cfs.compare_total(pol_data, agg=mean_attrs)
        results['cf_pv_total_base'] = cluster_cfs.compare_total(pvs)
        results['cf_pv_total_lapse'] = cluster_cfs.compare_total(pvs_lapse50)
        results['cf_pv_total_mort'] = cluster_cfs.compare_total(pvs_mort15)
        results['cf_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_cfs, scen_titles)
        results['cf_scatter_cashflows_base'] = plot_scatter_comparison(
            cluster_cfs.compare(cfs), 'Cashflow Calib. - Cashflows (Base)')
        error_data['CF Calib. (PV NetCF)'] = _abs_scenario_errors(
            [results['cf_pv_total_base'], results['cf_pv_total_lapse'], results['cf_pv_total_mort']],
            use_net_cf)

        # --- Calibration 2: policy attributes (min-max scaled) ---
        if pol_data.empty:
            gr.Warning("Policy data for attribute calibration is empty. Skipping attribute calibration plots.")
            results['attr_total_cf_base'] = pd.DataFrame()
            results['attr_policy_attrs_total'] = pd.DataFrame()
            results['attr_total_pv_base'] = pd.DataFrame()
            results['attr_cashflow_plot'] = None
            results['attr_scatter_cashflows_base'] = None
            error_data['Attr Calib. (PV NetCF)'] = [np.nan, np.nan, np.nan]
        else:
            cluster_attrs = Clusters(_min_max_scale(pol_data))
            results['attr_total_cf_base'] = cluster_attrs.compare_total(cfs)
            results['attr_policy_attrs_total'] = cluster_attrs.compare_total(pol_data, agg=mean_attrs)
            results['attr_total_pv_base'] = cluster_attrs.compare_total(pvs)
            results['attr_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_attrs, scen_titles)
            results['attr_scatter_cashflows_base'] = plot_scatter_comparison(
                cluster_attrs.compare(cfs), 'Policy Attr. Calib. - Cashflows (Base)')
            error_data['Attr Calib. (PV NetCF)'] = _abs_scenario_errors(
                [results['attr_total_pv_base'],
                 cluster_attrs.compare_total(pvs_lapse50),
                 cluster_attrs.compare_total(pvs_mort15)],
                use_net_cf)

        # --- Calibration 3: base-scenario present values ---
        cluster_pvs = Clusters(pvs)
        results['pv_total_cf_base'] = cluster_pvs.compare_total(cfs)
        results['pv_policy_attrs_total'] = cluster_pvs.compare_total(pol_data, agg=mean_attrs)
        results['pv_total_pv_base'] = cluster_pvs.compare_total(pvs)
        results['pv_total_pv_lapse'] = cluster_pvs.compare_total(pvs_lapse50)
        results['pv_total_pv_mort'] = cluster_pvs.compare_total(pvs_mort15)
        results['pv_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_pvs, scen_titles)
        results['pv_scatter_pvs_base'] = plot_scatter_comparison(
            cluster_pvs.compare(pvs), 'PV Calib. - PVs (Base)')
        error_data['PV Calib. (PV NetCF)'] = _abs_scenario_errors(
            [results['pv_total_pv_base'], results['pv_total_pv_lapse'], results['pv_total_pv_mort']],
            use_net_cf)

        # --- Summary bar chart ---
        summary_df = pd.DataFrame(error_data, index=scen_titles)
        fig_summary, ax_summary = plt.subplots(figsize=(10, 6))
        try:
            summary_df.plot(kind='bar', ax=ax_summary, grid=True)
            ax_summary.set_ylabel('Mean Absolute Error (of PV_NetCF)')
            ax_summary.set_title('Calibration Method Comparison - Error in Total PV Net Cashflow')
            ax_summary.tick_params(axis='x', rotation=0)
            fig_summary.tight_layout()
            buf_summary = io.BytesIO()
            fig_summary.savefig(buf_summary, format='png', dpi=100)
            buf_summary.seek(0)
            results['summary_plot'] = Image.open(buf_summary)
        finally:
            plt.close(fig_summary)

        return results

    # gr.Error only has an effect when raised; gr.Warning displays the message
    # while keeping this function's "return an error dict" contract.
    except FileNotFoundError as e:
        gr.Warning(f"File not found: {e.filename}. Please ensure example files are in '{EXAMPLE_DATA_DIR}' or all files are uploaded.")
        return {"error": f"File not found: {e.filename}"}
    except KeyError as e:
        gr.Warning(f"A required column is missing from one of the excel files: {e}. Please check data format.")
        return {"error": f"Missing column: {e}"}
    except Exception as e:
        gr.Warning(f"Error processing files: {str(e)}")
        return {"error": f"Error processing files: {str(e)}"}
|
|
|
|
|
def create_interface():
    """Build the Gradio Blocks UI for the cluster model-point analysis.

    The UI has seven Excel file inputs, a button that fills them with the
    paths in ``EXAMPLE_FILES``, an "Analyze Dataset" button that runs
    ``process_files``, and one tab per calibration method showing the
    resulting tables and plots.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks(title="Cluster Model Points Analysis") as demo:
        gr.Markdown("""
        # Cluster Model Points Analysis

        This application applies cluster analysis to model point selection for insurance portfolios.
        Upload your Excel files or use the example data to analyze cashflows, policy attributes, and present values using different calibration methods.

        **Required Files (Excel .xlsx):**
        - Cashflows - Base Scenario
        - Cashflows - Lapse Stress (+50%)
        - Cashflows - Mortality Stress (+15%)
        - Policy Data (including 'age_at_entry', 'policy_term', 'sum_assured', 'duration_mth')
        - Present Values - Base Scenario
        - Present Values - Lapse Stress
        - Present Values - Mortality Stress
        """)

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Upload Files or Load Examples")
                load_example_btn = gr.Button("Load Example Data")
                with gr.Row():
                    cashflow_base_input = gr.File(label="Cashflows - Base", file_types=[".xlsx"])
                    cashflow_lapse_input = gr.File(label="Cashflows - Lapse Stress", file_types=[".xlsx"])
                    cashflow_mort_input = gr.File(label="Cashflows - Mortality Stress", file_types=[".xlsx"])
                with gr.Row():
                    policy_data_input = gr.File(label="Policy Data", file_types=[".xlsx"])
                    pv_base_input = gr.File(label="Present Values - Base", file_types=[".xlsx"])
                    pv_lapse_input = gr.File(label="Present Values - Lapse Stress", file_types=[".xlsx"])
                with gr.Row():
                    pv_mort_input = gr.File(label="Present Values - Mortality Stress", file_types=[".xlsx"])
                analyze_btn = gr.Button("Analyze Dataset", variant="primary", size="lg")

        with gr.Tabs():
            with gr.TabItem(" Summary"):
                summary_plot_output = gr.Image(label="Calibration Methods Comparison (Error in Total PV Net Cashflow)")

            with gr.TabItem(" Cashflow Calibration"):
                gr.Markdown("### Results: Using Annual Cashflows as Calibration Variables")
                with gr.Row():
                    cf_total_base_table_out = gr.Dataframe(label="Overall Comparison - Base Scenario (Cashflows)")
                    cf_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes")
                cf_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate) Across Scenarios")
                cf_scatter_cashflows_base_out = gr.Image(label="Scatter Plot - Per-Cluster Cashflows (Base Scenario)")
                with gr.Accordion("Present Value Comparisons (Total)", open=False):
                    with gr.Row():
                        cf_pv_total_base_out = gr.Dataframe(label="PVs - Base Total")
                        cf_pv_total_lapse_out = gr.Dataframe(label="PVs - Lapse Stress Total")
                        cf_pv_total_mort_out = gr.Dataframe(label="PVs - Mortality Stress Total")

            with gr.TabItem(" Policy Attribute Calibration"):
                gr.Markdown("### Results: Using Policy Attributes as Calibration Variables")
                with gr.Row():
                    attr_total_cf_base_out = gr.Dataframe(label="Overall Comparison - Base Scenario (Cashflows)")
                    attr_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes")
                attr_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate) Across Scenarios")
                attr_scatter_cashflows_base_out = gr.Image(label="Scatter Plot - Per-Cluster Cashflows (Base Scenario)")
                with gr.Accordion("Present Value Comparisons (Total)", open=False):
                    attr_total_pv_base_out = gr.Dataframe(label="PVs - Base Scenario Total")

            with gr.TabItem(" Present Value Calibration"):
                gr.Markdown("### Results: Using Present Values (Base Scenario) as Calibration Variables")
                with gr.Row():
                    pv_total_cf_base_out = gr.Dataframe(label="Overall Comparison - Base Scenario (Cashflows)")
                    pv_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes")
                pv_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate) Across Scenarios")
                pv_scatter_pvs_base_out = gr.Image(label="Scatter Plot - Per-Cluster Present Values (Base Scenario)")
                with gr.Accordion("Present Value Comparisons (Total)", open=False):
                    with gr.Row():
                        pv_total_pv_base_out = gr.Dataframe(label="PVs - Base Total")
                        pv_total_pv_lapse_out = gr.Dataframe(label="PVs - Lapse Stress Total")
                        pv_total_pv_mort_out = gr.Dataframe(label="PVs - Mortality Stress Total")

        # Single source of truth pairing each process_files() result key with the
        # component that displays it; insertion order is the handler's output order.
        result_outputs = {
            'summary_plot': summary_plot_output,
            'cf_total_base_table': cf_total_base_table_out,
            'cf_policy_attrs_total': cf_policy_attrs_total_out,
            'cf_cashflow_plot': cf_cashflow_plot_out,
            'cf_scatter_cashflows_base': cf_scatter_cashflows_base_out,
            'cf_pv_total_base': cf_pv_total_base_out,
            'cf_pv_total_lapse': cf_pv_total_lapse_out,
            'cf_pv_total_mort': cf_pv_total_mort_out,
            'attr_total_cf_base': attr_total_cf_base_out,
            'attr_policy_attrs_total': attr_policy_attrs_total_out,
            'attr_cashflow_plot': attr_cashflow_plot_out,
            'attr_scatter_cashflows_base': attr_scatter_cashflows_base_out,
            'attr_total_pv_base': attr_total_pv_base_out,
            'pv_total_cf_base': pv_total_cf_base_out,
            'pv_policy_attrs_total': pv_policy_attrs_total_out,
            'pv_cashflow_plot': pv_cashflow_plot_out,
            'pv_scatter_pvs_base': pv_scatter_pvs_base_out,
            'pv_total_pv_base': pv_total_pv_base_out,
            'pv_total_pv_lapse': pv_total_pv_lapse_out,
            'pv_total_pv_mort': pv_total_pv_mort_out,
        }
        output_components = list(result_outputs.values())
        file_inputs = [cashflow_base_input, cashflow_lapse_input, cashflow_mort_input,
                       policy_data_input, pv_base_input, pv_lapse_input, pv_mort_input]

        def handle_analysis(f1, f2, f3, f4, f5, f6, f7):
            """Resolve the seven file inputs to paths and run ``process_files``.

            Each input may be a path string or a tempfile-like object with a
            string ``name``. A missing or unrecognised input aborts with a
            ``gr.Error`` (which must be raised to be displayed). When
            ``process_files`` returns an error dict, every output is cleared.
            """
            file_paths = []
            for i, f_obj in enumerate([f1, f2, f3, f4, f5, f6, f7]):
                if f_obj is None:
                    raise gr.Error(f"Missing file input for argument {i+1}. Please upload all files or load examples.")
                if hasattr(f_obj, 'name') and isinstance(f_obj.name, str):
                    file_paths.append(f_obj.name)
                elif isinstance(f_obj, str):
                    file_paths.append(f_obj)
                else:
                    raise gr.Error(f"Invalid file input for argument {i+1}. Type: {type(f_obj)}")

            results = process_files(*file_paths)
            if "error" in results:
                return [None] * len(output_components)
            return [results.get(key) for key in result_outputs]

        analyze_btn.click(
            handle_analysis,
            inputs=file_inputs,
            outputs=output_components,
        )

        def load_example_files():
            """Fill the file inputs with the bundled example paths.

            Raises ``gr.Error`` listing any example file that does not exist.
            """
            missing_files = [fp for fp in EXAMPLE_FILES.values() if not os.path.exists(fp)]
            if missing_files:
                raise gr.Error(f"Missing example data files in '{EXAMPLE_DATA_DIR}': {', '.join(missing_files)}. Please ensure they exist.")

            gr.Info("Example data paths loaded. Click 'Analyze Dataset'.")
            return [
                EXAMPLE_FILES["cashflow_base"], EXAMPLE_FILES["cashflow_lapse"], EXAMPLE_FILES["cashflow_mort"],
                EXAMPLE_FILES["policy_data"], EXAMPLE_FILES["pv_base"], EXAMPLE_FILES["pv_lapse"],
                EXAMPLE_FILES["pv_mort"],
            ]

        load_example_btn.click(
            load_example_files,
            inputs=[],
            outputs=file_inputs,
        )

    return demo
|
|
|
if __name__ == "__main__":
    # First run: create the example-data folder so the user knows where the
    # workbooks are expected, then list the expected file paths.
    data_dir_missing = not os.path.exists(EXAMPLE_DATA_DIR)
    if data_dir_missing:
        os.makedirs(EXAMPLE_DATA_DIR)
        print(f"Created directory '{EXAMPLE_DATA_DIR}'. Please place example Excel files there.")

    expected_paths = list(EXAMPLE_FILES.values())
    print(f"Expected files in '{EXAMPLE_DATA_DIR}': {expected_paths}")

    demo_app = create_interface()
    demo_app.launch()