Spaces:
Sleeping
Sleeping
import gradio as gr | |
import numpy as np | |
import pandas as pd | |
from sklearn.cluster import KMeans | |
from sklearn.metrics import pairwise_distances_argmin_min, r2_score | |
import matplotlib.pyplot as plt | |
import matplotlib.cm | |
import io | |
import os # Added for path joining | |
from PIL import Image | |
# Bundled example workbooks, keyed by the upload slot each one is loaded into.
EXAMPLE_DATA_DIR = "eg_data"
EXAMPLE_FILES = {
    slot: os.path.join(EXAMPLE_DATA_DIR, workbook)
    for slot, workbook in (
        ("cashflow_base", "cashflows_seriatim_10K.xlsx"),
        ("cashflow_lapse", "cashflows_seriatim_10K_lapse50.xlsx"),
        ("cashflow_mort", "cashflows_seriatim_10K_mort15.xlsx"),
        ("policy_data", "model_point_table.xlsx"),
        ("pv_base", "pv_seriatim_10K.xlsx"),
        ("pv_lapse", "pv_seriatim_10K_lapse50.xlsx"),
        ("pv_mort", "pv_seriatim_10K_mort15.xlsx"),
    )
}
class Clusters:
    """K-means clustering of policies with one representative policy per cluster.

    Each cluster is represented by the policy nearest to its centroid. Portfolio
    figures are then estimated by scaling each representative's values by the
    number of policies in its cluster.
    """

    def __init__(self, loc_vars, n_clusters=1000, random_state=0):
        """Fit K-means on ``loc_vars`` and select a representative per cluster.

        Args:
            loc_vars: Array-like of calibration variables, one row per policy.
                Row ``i`` (0-based) is assumed to be policy_id ``i + 1``.
            n_clusters: Requested number of clusters (default 1000). It is
                clamped to the number of rows so that small portfolios do not
                make KMeans fail.
            random_state: Seed passed to KMeans for reproducible clustering.
        """
        points = np.ascontiguousarray(loc_vars)
        n_clusters = min(n_clusters, len(points))
        self.kmeans = kmeans = KMeans(
            n_clusters=n_clusters, random_state=random_state, n_init=10
        ).fit(points)
        # Positional (0-based) index of the row nearest to each centroid.
        closest, _ = pairwise_distances_argmin_min(kmeans.cluster_centers_, points)
        rep_ids = pd.Series(data=(closest + 1))  # 0-based to 1-based indexes
        rep_ids.name = 'policy_id'
        rep_ids.index.name = 'cluster_id'
        self.rep_ids = rep_ids
        # Number of policies in each cluster, indexed by cluster_id.
        self.policy_count = self.agg_by_cluster(
            pd.DataFrame({'policy_count': [1] * len(points)})
        )['policy_count']

    def agg_by_cluster(self, df, agg=None):
        """Aggregate the rows of ``df`` by cluster.

        Args:
            df: DataFrame with one row per policy, in the same row order as the
                data the clusters were fitted on (labels are assigned by position).
            agg: Optional mapping of column name to aggregation (e.g. 'mean').
                Columns not listed are summed. When omitted, every column is summed.

        Returns:
            DataFrame indexed by cluster_id.
        """
        temp = df.copy()
        temp['cluster_id'] = self.kmeans.labels_
        temp = temp.set_index('cluster_id')
        how = {c: agg.get(c, 'sum') for c in temp.columns} if agg else 'sum'
        return temp.groupby(temp.index).agg(how)

    def extract_reps(self, df):
        """Return the rows of ``df`` that belong to the representative policies.

        ``df`` is expected to be indexed by 'policy_id', because it is merged on
        that name after ``reset_index``. The left merge keeps the order of
        ``rep_ids``, so row position equals cluster_id. The policy_id column is
        dropped from the result.
        """
        temp = pd.merge(self.rep_ids, df.reset_index(), how='left', on='policy_id')
        temp.index.name = 'cluster_id'
        return temp.drop('policy_id', axis=1)

    def extract_and_scale_reps(self, df, agg=None):
        """Extract the representative rows and scale them up to cluster level.

        Columns aggregated by 'sum' (the default for columns missing from
        ``agg``) are multiplied by the cluster's policy count. Columns with any
        other aggregation in ``agg`` are left unscaled.
        """
        reps = self.extract_reps(df)
        if not agg:
            return reps.mul(self.policy_count, axis=0)
        # Passing the index explicitly keeps this working when no column is
        # scaled: a dict of scalars alone makes pd.DataFrame raise ValueError.
        mult = pd.DataFrame(
            {c: (self.policy_count if agg.get(c, 'sum') == 'sum' else 1) for c in df.columns},
            index=self.policy_count.index,
        )
        # Ensure mult has the same index as the extracted reps for alignment.
        mult.index = reps.index
        return reps.mul(mult)

    def compare(self, df, agg=None):
        """Return a (cluster_id, column)-indexed DataFrame of 'actual' vs 'estimate'."""
        source = self.agg_by_cluster(df, agg)
        target = self.extract_and_scale_reps(df, agg)
        return pd.DataFrame({'actual': source.stack(), 'estimate': target.stack()})

    def compare_total(self, df, agg=None):
        """Compare portfolio-level totals of ``df`` with the cluster-based estimate.

        Args:
            df: DataFrame indexed by policy_id, one row per policy.
            agg: Optional mapping of column name to 'mean' or 'sum'.
                - For 'mean' columns, the portfolio mean is compared with the
                  policy-count-weighted mean of the representatives.
                - Every other column is summed.

        Returns:
            DataFrame indexed by column name with three columns:
            - 'actual';
            - 'estimate';
            - 'error', equal to ``estimate / actual - 1``, or 0 where actual is 0.
        """
        if agg:
            reps = self.extract_reps(df)
            total_weight = self.policy_count.sum()
            actual_values = {}
            estimate_values = {}
            for col in df.columns:
                weighted_sum = (reps[col] * self.policy_count).sum()
                if agg.get(col, 'sum') == 'mean':
                    actual_values[col] = df[col].mean()
                    estimate_values[col] = weighted_sum / total_weight if total_weight > 0 else 0
                else:  # sum
                    actual_values[col] = df[col].sum()
                    estimate_values[col] = weighted_sum
            actual = pd.Series(actual_values)
            estimate = pd.Series(estimate_values)
        else:  # every column is summed
            actual = df.sum()
            estimate = self.extract_and_scale_reps(df).sum()
        # Label-aligned relative error; zero actuals report 0 instead of inf/NaN.
        error = (estimate / actual - 1).where(actual != 0, 0)
        return pd.DataFrame({'actual': actual, 'estimate': estimate, 'error': error})
def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
    """Plot actual vs. estimated total cashflows, one subplot per scenario.

    Args:
        cfs_list: Cashflow DataFrames, one per scenario. Each is passed to
            ``cluster_obj.compare_total``.
        cluster_obj: Fitted clustering object that provides ``compare_total``.
        titles: Subplot titles, paired with ``cfs_list`` in order.

    Returns:
        A PIL Image of the figure, or None when there is nothing to plot.
    """
    if not cfs_list or not cluster_obj or not titles:
        return None
    # zip() truncates to the shorter list, so size the grid from the pairs
    # actually plotted rather than from cfs_list alone.
    pairs = list(zip(cfs_list, titles))
    cols = 2
    rows = (len(pairs) + cols - 1) // cols
    fig, axes = plt.subplots(rows, cols, figsize=(15, 5 * rows), squeeze=False)
    try:
        axes = axes.flatten()
        for ax, (df, title) in zip(axes, pairs):
            comparison = cluster_obj.compare_total(df)
            comparison[['actual', 'estimate']].plot(ax=ax, grid=True, title=title)
            ax.set_xlabel('Time')
            ax.set_ylabel('Value')
        # Hide any unused subplots.
        for ax in axes[len(pairs):]:
            fig.delaxes(ax)
        # Call the figure's own methods instead of pyplot's "current figure",
        # which another request may change in a multi-threaded server.
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=100)
    finally:
        # Always release the figure, even if plotting raised.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
def _draw_actual_vs_estimate(ax, df_compare_output):
    """Draw the actual-vs-estimate scatter, axis labels, grid and y = x line on ``ax``."""
    index = df_compare_output.index
    if not isinstance(index, pd.MultiIndex) or index.nlevels < 2:
        gr.Warning("Scatter plot data is not in the expected multi-index format. Plotting raw actual vs estimate.")
        ax.scatter(df_compare_output['actual'], df_compare_output['estimate'], s=9, alpha=0.6)
    else:
        # One colour per value of the second index level, e.g. the stacked
        # column name produced by Clusters.compare.
        unique_levels = index.get_level_values(1).unique()
        colors = matplotlib.cm.rainbow(np.linspace(0, 1, len(unique_levels)))
        for item_level, color_val in zip(unique_levels, colors):
            subset = df_compare_output.xs(item_level, level=1)
            ax.scatter(subset['actual'], subset['estimate'], color=color_val, s=9, alpha=0.6, label=item_level)
        # A legend is only readable for a handful of series.
        if 1 < len(unique_levels) <= 10:
            ax.legend(title=index.names[1])
    ax.set_xlabel('Actual')
    ax.set_ylabel('Estimate')
    ax.grid(True)
    # Identity line y = x over a common square range for both axes.
    lims = [
        np.min([ax.get_xlim(), ax.get_ylim()]),
        np.max([ax.get_xlim(), ax.get_ylim()]),
    ]
    if lims[0] != lims[1]:
        ax.plot(lims, lims, 'r-', linewidth=0.5)
        ax.set_xlim(lims)
        ax.set_ylim(lims)


def plot_scatter_comparison(df_compare_output, title):
    """Scatter-plot per-cluster actual vs. estimate values from ``Clusters.compare``.

    Args:
        df_compare_output: DataFrame with 'actual' and 'estimate' columns. It is
            normally indexed by (cluster_id, item), and points are coloured by
            the second level. Any other index shape is plotted in one colour.
        title: Plot title.

    Returns:
        A PIL Image of the figure. When there is no data, the image shows a
        placeholder message instead.
    """
    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        if df_compare_output is None or df_compare_output.empty:
            # Placeholder so the UI still receives an image.
            ax.text(0.5, 0.5, "No data to display", ha='center', va='center', fontsize=15)
        else:
            _draw_actual_vs_estimate(ax, df_compare_output)
        ax.set_title(title)
        # Save via the figure itself, not pyplot's thread-shared "current figure".
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=100)
    finally:
        # Always release the figure, even if plotting raised.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
def _standardize_attrs(pol_data):
    """Min-max scale each column of ``pol_data`` to [0, 1]; constant columns map to 0."""
    value_range = pol_data.max() - pol_data.min()
    # Dividing by 1 for zero-range columns avoids NaN/inf without dropping them.
    return (pol_data - pol_data.min()) / value_range.replace(0, 1)


def _abs_error(compare_result, col_name=None):
    """Get the absolute error for one row of a ``compare_total`` result.

    Returns the absolute 'error' for ``col_name``. If that row is absent, returns
    the mean absolute error over all rows, and NaN for an empty result.
    """
    if compare_result.empty:
        return np.nan
    if col_name and col_name in compare_result.index:
        return abs(compare_result.loc[col_name, 'error'])
    return compare_result['error'].abs().mean()


def _find_key_pv_col(pvs):
    """Return the first recognised net-cashflow PV column name in ``pvs``, or None."""
    candidates = ['PV_NetCF', 'pv_net_cf', 'net_cf_pv', 'PV_Net_CF']
    return next((col for col in candidates if col in pvs.columns), None)


def _render_summary_plot(summary_df, key_pv_col):
    """Render a bar chart of absolute PV error by scenario and method, as a PIL Image."""
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        summary_df.plot(kind='bar', ax=ax, grid=True)
        ax.set_ylabel('Absolute Error Rate')
        title_suffix = f' ({key_pv_col})' if key_pv_col else ' (Mean Absolute Error)'
        ax.set_title(f'Calibration Method Comparison - Error in Total PV{title_suffix}')
        ax.tick_params(axis='x', rotation=0)
        ax.legend(title='Calibration Method')
        # Use figure methods, not pyplot's thread-shared "current figure".
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=100)
    finally:
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
                  policy_data_path, pv_base_path, pv_lapse_path, pv_mort_path):
    """Run cashflow, policy-attribute and PV calibrations on seven Excel files.

    Every file is read with its first column as the index. This is presumably
    policy_id, which ``Clusters.extract_reps`` merges on.

    Args:
        cashflow_base_path, cashflow_lapse_path, cashflow_mort_path: Cashflow
            workbooks for the base, lapse +50% and mortality +15% scenarios.
        policy_data_path: Policy attribute workbook.
        pv_base_path, pv_lapse_path, pv_mort_path: Present-value workbooks for
            the same three scenarios.

    Returns:
        On success, a dict of result tables (DataFrames) and plots (PIL Images).
        On failure, ``{"error": message}`` after showing a Gradio warning.
    """
    try:
        cfs = pd.read_excel(cashflow_base_path, index_col=0)
        cfs_lapse50 = pd.read_excel(cashflow_lapse_path, index_col=0)
        cfs_mort15 = pd.read_excel(cashflow_mort_path, index_col=0)
        pol_data_full = pd.read_excel(policy_data_path, index_col=0)
        # Keep only the calibration attributes when they are all present.
        required_cols = ['age_at_entry', 'policy_term', 'sum_assured', 'duration_mth']
        if all(col in pol_data_full.columns for col in required_cols):
            pol_data = pol_data_full[required_cols]
        else:
            gr.Warning(f"Policy data might be missing required columns. Found: {pol_data_full.columns.tolist()}")
            pol_data = pol_data_full
        pvs = pd.read_excel(pv_base_path, index_col=0)
        pvs_lapse50 = pd.read_excel(pv_lapse_path, index_col=0)
        pvs_mort15 = pd.read_excel(pv_mort_path, index_col=0)

        cfs_list = [cfs, cfs_lapse50, cfs_mort15]
        scen_titles = ['Base', 'Lapse+50%', 'Mort+15%']
        results = {}
        mean_attrs = {'age_at_entry': 'mean', 'policy_term': 'mean', 'duration_mth': 'mean', 'sum_assured': 'sum'}

        # --- 1. Cashflow Calibration ---
        cluster_cfs = Clusters(cfs)
        results['cf_total_base_table'] = cluster_cfs.compare_total(cfs)
        results['cf_policy_attrs_total'] = cluster_cfs.compare_total(pol_data, agg=mean_attrs)
        results['cf_pv_total_base'] = cluster_cfs.compare_total(pvs)
        results['cf_pv_total_lapse'] = cluster_cfs.compare_total(pvs_lapse50)
        results['cf_pv_total_mort'] = cluster_cfs.compare_total(pvs_mort15)
        results['cf_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_cfs, scen_titles)
        results['cf_scatter_cashflows_base'] = plot_scatter_comparison(cluster_cfs.compare(cfs), 'Cashflow Calib. - Cashflows (Base)')

        # --- 2. Policy Attribute Calibration ---
        if pol_data.empty:
            gr.Warning("Policy data for attribute calibration is empty. Skipping attribute calibration.")
            cluster_attrs = None
            results['attr_total_cf_base'] = pd.DataFrame()
            results['attr_policy_attrs_total'] = pd.DataFrame()
            results['attr_total_pv_base'] = pd.DataFrame()
            results['attr_cashflow_plot'] = None
            results['attr_scatter_cashflows_base'] = None
        else:
            # Standardize so no attribute dominates the K-means distance.
            cluster_attrs = Clusters(_standardize_attrs(pol_data))
            results['attr_total_cf_base'] = cluster_attrs.compare_total(cfs)
            results['attr_policy_attrs_total'] = cluster_attrs.compare_total(pol_data, agg=mean_attrs)
            results['attr_total_pv_base'] = cluster_attrs.compare_total(pvs)
            results['attr_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_attrs, scen_titles)
            results['attr_scatter_cashflows_base'] = plot_scatter_comparison(cluster_attrs.compare(cfs), 'Policy Attr. Calib. - Cashflows (Base)')

        # --- 3. Present Value Calibration ---
        cluster_pvs = Clusters(pvs)
        results['pv_total_cf_base'] = cluster_pvs.compare_total(cfs)
        results['pv_policy_attrs_total'] = cluster_pvs.compare_total(pol_data, agg=mean_attrs)
        results['pv_total_pv_base'] = cluster_pvs.compare_total(pvs)
        results['pv_total_pv_lapse'] = cluster_pvs.compare_total(pvs_lapse50)
        results['pv_total_pv_mort'] = cluster_pvs.compare_total(pvs_mort15)
        results['pv_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_pvs, scen_titles)
        results['pv_scatter_pvs_base'] = plot_scatter_comparison(cluster_pvs.compare(pvs), 'PV Calib. - PVs (Base)')

        # --- Summary: error in total PV per method and scenario ---
        # Reuse the PV comparisons already computed above where available.
        key_pv_col = _find_key_pv_col(pvs)
        if cluster_attrs is not None:
            attr_errors = [
                _abs_error(results['attr_total_pv_base'], key_pv_col),
                _abs_error(cluster_attrs.compare_total(pvs_lapse50), key_pv_col),
                _abs_error(cluster_attrs.compare_total(pvs_mort15), key_pv_col),
            ]
        else:
            attr_errors = [np.nan, np.nan, np.nan]
        error_data = {
            'CF Calib.': [
                _abs_error(results['cf_pv_total_base'], key_pv_col),
                _abs_error(results['cf_pv_total_lapse'], key_pv_col),
                _abs_error(results['cf_pv_total_mort'], key_pv_col),
            ],
            'Attr Calib.': attr_errors,
            'PV Calib.': [
                _abs_error(results['pv_total_pv_base'], key_pv_col),
                _abs_error(results['pv_total_pv_lapse'], key_pv_col),
                _abs_error(results['pv_total_pv_mort'], key_pv_col),
            ],
        }
        summary_df = pd.DataFrame(error_data, index=scen_titles)
        results['summary_plot'] = _render_summary_plot(summary_df, key_pv_col)
        return results
    # gr.Error only reaches the user when raised. gr.Warning displays the
    # message while keeping the {"error": ...} return contract callers check.
    except FileNotFoundError as e:
        gr.Warning(f"File not found: {e.filename}. Please ensure example files are in '{EXAMPLE_DATA_DIR}' or all files are uploaded.")
        return {"error": f"File not found: {e.filename}"}
    except KeyError as e:
        gr.Warning(f"A required column is missing from one of the excel files: {e}. Please check data format.")
        return {"error": f"Missing column: {e}"}
    except Exception as e:
        gr.Warning(f"Error processing files: {str(e)}")
        return {"error": f"Error processing files: {str(e)}"}
def create_interface():
    """Build the Gradio Blocks UI: file inputs, example loader and result tabs.

    Returns:
        The ``gr.Blocks`` app. It is not launched here.
    """
    with gr.Blocks(title="Cluster Model Points Analysis") as demo:
        gr.Markdown("""
# Cluster Model Points Analysis
This application applies cluster analysis to model point selection for insurance portfolios.
Upload your Excel files or use the example data to analyze cashflows, policy attributes, and present values using different calibration methods.
**Required Files (Excel .xlsx):**
- Cashflows - Base Scenario
- Cashflows - Lapse Stress (+50%)
- Cashflows - Mortality Stress (+15%)
- Policy Data (including 'age_at_entry', 'policy_term', 'sum_assured', 'duration_mth')
- Present Values - Base Scenario
- Present Values - Lapse Stress
- Present Values - Mortality Stress
""")
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Upload Files or Load Examples")
                load_example_btn = gr.Button("Load Example Data")
                with gr.Row():
                    cashflow_base_input = gr.File(label="Cashflows - Base", file_types=[".xlsx"])
                    cashflow_lapse_input = gr.File(label="Cashflows - Lapse Stress", file_types=[".xlsx"])
                    cashflow_mort_input = gr.File(label="Cashflows - Mortality Stress", file_types=[".xlsx"])
                with gr.Row():
                    policy_data_input = gr.File(label="Policy Data", file_types=[".xlsx"])
                    pv_base_input = gr.File(label="Present Values - Base", file_types=[".xlsx"])
                    pv_lapse_input = gr.File(label="Present Values - Lapse Stress", file_types=[".xlsx"])
                with gr.Row():
                    pv_mort_input = gr.File(label="Present Values - Mortality Stress", file_types=[".xlsx"])
                analyze_btn = gr.Button("Analyze Dataset", variant="primary", size="lg")

        with gr.Tabs():
            with gr.TabItem("📊 Summary"):
                summary_plot_output = gr.Image(label="Calibration Methods Comparison")
            with gr.TabItem("💸 Cashflow Calibration"):
                gr.Markdown("### Results: Using Annual Cashflows as Calibration Variables")
                with gr.Row():
                    cf_total_base_table_out = gr.Dataframe(label="Overall Comparison - Base Scenario (Cashflows)")
                    cf_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes")
                cf_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate) Across Scenarios")
                cf_scatter_cashflows_base_out = gr.Image(label="Scatter Plot - Per-Cluster Cashflows (Base Scenario)")
                with gr.Accordion("Present Value Comparisons (Total)", open=False):
                    with gr.Row():
                        cf_pv_total_base_out = gr.Dataframe(label="PVs - Base Total")
                        cf_pv_total_lapse_out = gr.Dataframe(label="PVs - Lapse Stress Total")
                        cf_pv_total_mort_out = gr.Dataframe(label="PVs - Mortality Stress Total")
            with gr.TabItem("👤 Policy Attribute Calibration"):
                gr.Markdown("### Results: Using Policy Attributes as Calibration Variables")
                with gr.Row():
                    attr_total_cf_base_out = gr.Dataframe(label="Overall Comparison - Base Scenario (Cashflows)")
                    attr_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes")
                attr_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate) Across Scenarios")
                attr_scatter_cashflows_base_out = gr.Image(label="Scatter Plot - Per-Cluster Cashflows (Base Scenario)")
                with gr.Accordion("Present Value Comparisons (Total)", open=False):
                    attr_total_pv_base_out = gr.Dataframe(label="PVs - Base Scenario Total")
            with gr.TabItem("💰 Present Value Calibration"):
                gr.Markdown("### Results: Using Present Values (Base Scenario) as Calibration Variables")
                with gr.Row():
                    pv_total_cf_base_out = gr.Dataframe(label="Overall Comparison - Base Scenario (Cashflows)")
                    pv_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes")
                pv_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate) Across Scenarios")
                pv_scatter_pvs_base_out = gr.Image(label="Scatter Plot - Per-Cluster Present Values (Base Scenario)")
                with gr.Accordion("Present Value Comparisons (Total)", open=False):
                    with gr.Row():
                        pv_total_pv_base_out = gr.Dataframe(label="PVs - Base Total")
                        pv_total_pv_lapse_out = gr.Dataframe(label="PVs - Lapse Stress Total")
                        pv_total_pv_mort_out = gr.Dataframe(label="PVs - Mortality Stress Total")

        # Each output component paired with the process_files() result key it
        # displays. Keeping them in one table stops the two orders from drifting.
        output_bindings = [
            (summary_plot_output, 'summary_plot'),
            # Cashflow Calib Outputs
            (cf_total_base_table_out, 'cf_total_base_table'),
            (cf_policy_attrs_total_out, 'cf_policy_attrs_total'),
            (cf_cashflow_plot_out, 'cf_cashflow_plot'),
            (cf_scatter_cashflows_base_out, 'cf_scatter_cashflows_base'),
            (cf_pv_total_base_out, 'cf_pv_total_base'),
            (cf_pv_total_lapse_out, 'cf_pv_total_lapse'),
            (cf_pv_total_mort_out, 'cf_pv_total_mort'),
            # Attribute Calib Outputs
            (attr_total_cf_base_out, 'attr_total_cf_base'),
            (attr_policy_attrs_total_out, 'attr_policy_attrs_total'),
            (attr_cashflow_plot_out, 'attr_cashflow_plot'),
            (attr_scatter_cashflows_base_out, 'attr_scatter_cashflows_base'),
            (attr_total_pv_base_out, 'attr_total_pv_base'),
            # PV Calib Outputs
            (pv_total_cf_base_out, 'pv_total_cf_base'),
            (pv_policy_attrs_total_out, 'pv_policy_attrs_total'),
            (pv_cashflow_plot_out, 'pv_cashflow_plot'),
            (pv_scatter_pvs_base_out, 'pv_scatter_pvs_base'),
            (pv_total_pv_base_out, 'pv_total_pv_base'),
            (pv_total_pv_lapse_out, 'pv_total_pv_lapse'),
            (pv_total_pv_mort_out, 'pv_total_pv_mort'),
        ]
        output_components = [component for component, _ in output_bindings]
        result_keys = [key for _, key in output_bindings]
        file_inputs = [cashflow_base_input, cashflow_lapse_input, cashflow_mort_input,
                       policy_data_input, pv_base_input, pv_lapse_input, pv_mort_input]

        # --- Action for Analyze Button ---
        def handle_analysis(f1, f2, f3, f4, f5, f6, f7):
            """Resolve the seven file inputs to paths, run the analysis, map results to outputs.

            Every output is cleared (None) when an input is missing or invalid,
            or when processing reports an error.
            """
            empty_outputs = [None] * len(output_components)
            file_paths = []
            for i, f_obj in enumerate([f1, f2, f3, f4, f5, f6, f7]):
                # gr.Error is only shown when raised. gr.Warning displays the
                # message and still lets us clear the outputs below.
                if f_obj is None:
                    gr.Warning(f"Missing file input for argument {i+1}. Please upload all files or load examples.")
                    return empty_outputs
                # Upload objects that expose their temp-file path as `.name`
                if hasattr(f_obj, 'name') and isinstance(f_obj.name, str):
                    file_paths.append(f_obj.name)
                # Plain string path (e.g. from the example loader)
                elif isinstance(f_obj, str):
                    file_paths.append(f_obj)
                else:
                    gr.Warning(f"Invalid file input for argument {i+1}. Type: {type(f_obj)}")
                    return empty_outputs
            results = process_files(*file_paths)
            if "error" in results:
                return empty_outputs
            return [results.get(key) for key in result_keys]

        analyze_btn.click(
            handle_analysis,
            inputs=file_inputs,
            outputs=output_components,
        )

        # --- Action for Load Example Data Button ---
        def load_example_files():
            """Fill the file inputs with the bundled example paths, if all of them exist."""
            missing_files = [fp for fp in EXAMPLE_FILES.values() if not os.path.exists(fp)]
            if missing_files:
                gr.Warning(f"Missing example data files in '{EXAMPLE_DATA_DIR}': {', '.join(missing_files)}. Please ensure they exist.")
                return [None] * len(file_inputs)
            gr.Info("Example data paths loaded. Click 'Analyze Dataset'.")
            return [
                EXAMPLE_FILES["cashflow_base"], EXAMPLE_FILES["cashflow_lapse"], EXAMPLE_FILES["cashflow_mort"],
                EXAMPLE_FILES["policy_data"], EXAMPLE_FILES["pv_base"], EXAMPLE_FILES["pv_lapse"],
                EXAMPLE_FILES["pv_mort"],
            ]

        load_example_btn.click(
            load_example_files,
            inputs=[],
            outputs=file_inputs,
        )
    return demo
if __name__ == "__main__":
    # On first run, create the example-data folder and tell the user what to put in it.
    example_dir_missing = not os.path.exists(EXAMPLE_DATA_DIR)
    if example_dir_missing:
        os.makedirs(EXAMPLE_DATA_DIR)
        print(
            f"Created directory '{EXAMPLE_DATA_DIR}'. Please place example Excel files there.",
            f"Expected files in '{EXAMPLE_DATA_DIR}': {list(EXAMPLE_FILES.values())}",
            sep="\n",
        )
    demo_app = create_interface()
    demo_app.launch()