alidenewade committed on
Commit
4072b44
·
verified ·
1 Parent(s): 4622ce0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +206 -247
app.py CHANGED
@@ -3,10 +3,11 @@ import numpy as np
3
  import pandas as pd
4
  from sklearn.cluster import KMeans
5
  from sklearn.metrics import pairwise_distances_argmin_min, r2_score
6
- import plotly.graph_objects as go
7
- import plotly.express as px
8
- from plotly.subplots import make_subplots
9
- import os
 
10
 
11
  # Define the paths for example data
12
  EXAMPLE_DATA_DIR = "eg_data"
@@ -25,7 +26,7 @@ class Clusters:
25
  self.kmeans = kmeans = KMeans(n_clusters=1000, random_state=0, n_init=10).fit(np.ascontiguousarray(loc_vars))
26
  closest, _ = pairwise_distances_argmin_min(kmeans.cluster_centers_, np.ascontiguousarray(loc_vars))
27
 
28
- rep_ids = pd.Series(data=(closest+1))
29
  rep_ids.name = 'policy_id'
30
  rep_ids.index.name = 'cluster_id'
31
  self.rep_ids = rep_ids
@@ -33,6 +34,7 @@ class Clusters:
33
  self.policy_count = self.agg_by_cluster(pd.DataFrame({'policy_count': [1] * len(loc_vars)}))['policy_count']
34
 
35
  def agg_by_cluster(self, df, agg=None):
 
36
  temp = df.copy()
37
  temp['cluster_id'] = self.kmeans.labels_
38
  temp = temp.set_index('cluster_id')
@@ -40,14 +42,17 @@ class Clusters:
40
  return temp.groupby(temp.index).agg(agg)
41
 
42
  def extract_reps(self, df):
 
43
  temp = pd.merge(self.rep_ids, df.reset_index(), how='left', on='policy_id')
44
  temp.index.name = 'cluster_id'
45
  return temp.drop('policy_id', axis=1)
46
 
47
  def extract_and_scale_reps(self, df, agg=None):
 
48
  if agg:
49
  cols = df.columns
50
  mult = pd.DataFrame({c: (self.policy_count if (c not in agg or agg[c] == 'sum') else 1) for c in cols})
 
51
  extracted_df = self.extract_reps(df)
52
  mult.index = extracted_df.index
53
  return extracted_df.mul(mult)
@@ -55,188 +60,145 @@ class Clusters:
55
  return self.extract_reps(df).mul(self.policy_count, axis=0)
56
 
57
  def compare(self, df, agg=None):
 
58
  source = self.agg_by_cluster(df, agg)
59
  target = self.extract_and_scale_reps(df, agg)
60
  return pd.DataFrame({'actual': source.stack(), 'estimate':target.stack()})
61
 
62
  def compare_total(self, df, agg=None):
 
63
  if agg:
 
64
  actual_values = {}
65
  for col in df.columns:
66
  if agg.get(col, 'sum') == 'mean':
67
  actual_values[col] = df[col].mean()
68
- else:
69
  actual_values[col] = df[col].sum()
70
  actual = pd.Series(actual_values)
71
 
 
72
  reps_unscaled = self.extract_reps(df)
73
  estimate_values = {}
74
 
75
  for col in df.columns:
76
  if agg.get(col, 'sum') == 'mean':
 
77
  weighted_sum = (reps_unscaled[col] * self.policy_count).sum()
78
  total_weight = self.policy_count.sum()
79
  estimate_values[col] = weighted_sum / total_weight if total_weight > 0 else 0
80
- else:
81
  estimate_values[col] = (reps_unscaled[col] * self.policy_count).sum()
 
82
  estimate = pd.Series(estimate_values)
83
- else:
 
84
  actual = df.sum()
85
  estimate = self.extract_and_scale_reps(df).sum()
86
 
 
87
  error = np.where(actual != 0, estimate / actual - 1, 0)
 
88
  return pd.DataFrame({'actual': actual, 'estimate': estimate, 'error': error})
89
 
90
 
91
  def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
92
- fig_width = 900 # Reduced width
93
- default_height_per_row = 300 # Reduced height per row
94
-
95
  if not cfs_list or not cluster_obj or not titles:
96
- fig = go.Figure()
97
- fig.add_annotation(text="No data for cashflow comparison plot.", xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False)
98
- fig.update_layout(width=fig_width, height=default_height_per_row)
99
- return fig
100
-
101
  num_plots = len(cfs_list)
102
  if num_plots == 0:
103
- fig = go.Figure()
104
- fig.add_annotation(text="No cashflows to plot.", xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False)
105
- fig.update_layout(width=fig_width, height=default_height_per_row)
106
- return fig
107
 
 
108
  cols = 2
109
  rows = (num_plots + cols - 1) // cols
110
- fig_height = default_height_per_row * rows
111
 
112
- subplot_titles_full = titles[:num_plots] + [""] * (rows * cols - num_plots)
113
-
114
- try:
115
- fig = make_subplots(
116
- rows=rows, cols=cols,
117
- subplot_titles=subplot_titles_full
118
- )
119
-
120
- plot_idx = 0
121
- for i_df, (df, title) in enumerate(zip(cfs_list, titles)):
122
- if plot_idx < rows * cols:
123
- r = plot_idx // cols + 1
124
- c = plot_idx % cols + 1
125
- comparison = cluster_obj.compare_total(df)
126
-
127
- fig.add_trace(go.Scatter(x=comparison.index, y=comparison['actual'], name='Actual',
128
- legendgroup='Actual', showlegend=(plot_idx == 0)), row=r, col=c)
129
- fig.add_trace(go.Scatter(x=comparison.index, y=comparison['estimate'], name='Estimate',
130
- legendgroup='Estimate', showlegend=(plot_idx == 0)), row=r, col=c)
131
-
132
- fig.update_xaxes(title_text='Time', showgrid=True, row=r, col=c)
133
- fig.update_yaxes(title_text='Value', showgrid=True, row=r, col=c)
134
- plot_idx += 1
135
 
136
- for i in range(plot_idx, rows * cols):
137
- r = i // cols + 1
138
- c = i % cols + 1
139
- fig.update_xaxes(visible=False, row=r, col=c)
140
- fig.update_yaxes(visible=False, row=r, col=c)
141
- if fig.layout.annotations and i < len(fig.layout.annotations):
142
- fig.layout.annotations[i].update(text="")
143
-
144
- fig.update_layout(
145
- width=fig_width,
146
- height=fig_height,
147
- margin=dict(l=60, r=30, t=60, b=60) # Keep reasonable margins
148
- )
149
- return fig
150
- except Exception as e:
151
- print(f"Error generating cashflow plot: {e}")
152
- fig = go.Figure()
153
- fig.add_annotation(text=f"Plot Error: {e}", xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False)
154
- fig.update_layout(width=fig_width, height=fig_height if rows > 0 else default_height_per_row)
155
- return fig
156
-
157
 
158
  def plot_scatter_comparison(df_compare_output, title):
159
- fig_width = 800 # Reduced width
160
- fig_height = 550 # Reduced height
161
-
162
- try:
163
- if df_compare_output is None or df_compare_output.empty:
164
- fig = go.Figure()
165
- fig.add_annotation(
166
- text="No data to display for scatter plot.",
167
- xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False,
168
- font=dict(size=15)
169
- )
170
- fig.update_layout(title_text=title, width=fig_width, height=fig_height)
171
- return fig
172
-
173
- if not isinstance(df_compare_output.index, pd.MultiIndex) or df_compare_output.index.nlevels < 2:
174
- gr.Warning("Scatter plot data is not in the expected multi-index format. Plotting raw actual vs estimate.")
175
- fig = px.scatter(df_compare_output, x='actual', y='estimate', title=title)
176
- fig.update_traces(marker=dict(size=5, opacity=0.6))
177
- else:
178
- df_reset = df_compare_output.reset_index()
179
- level_1_name = df_compare_output.index.names[1] if df_compare_output.index.names[1] else 'category'
180
- if level_1_name not in df_reset.columns and len(df_reset.columns) > 1:
181
- df_reset = df_reset.rename(columns={df_reset.columns[1]: level_1_name})
182
-
183
- fig = px.scatter(df_reset, x='actual', y='estimate', color=level_1_name,
184
- title=title,
185
- labels={'actual': 'Actual', 'estimate': 'Estimate', level_1_name: level_1_name})
186
- fig.update_traces(marker=dict(size=5, opacity=0.6))
187
-
188
- num_unique_levels = df_reset[level_1_name].nunique() if level_1_name in df_reset else 0
189
- if num_unique_levels == 0 or num_unique_levels > 10:
190
- fig.update_layout(showlegend=False)
191
- elif num_unique_levels >= 1 :
192
- fig.update_layout(showlegend=True)
193
-
194
- fig.update_xaxes(showgrid=True, title_text='Actual')
195
- fig.update_yaxes(showgrid=True, title_text='Estimate')
196
-
197
- if not df_compare_output.empty:
198
- min_val_actual = df_compare_output['actual'].min()
199
- max_val_actual = df_compare_output['actual'].max()
200
- min_val_estimate = df_compare_output['estimate'].min()
201
- max_val_estimate = df_compare_output['estimate'].max()
202
-
203
- if pd.isna(min_val_actual) or pd.isna(min_val_estimate) or pd.isna(max_val_actual) or pd.isna(max_val_estimate):
204
- lims = [0,1]
205
- else:
206
- overall_min = min(min_val_actual, min_val_estimate)
207
- overall_max = max(max_val_actual, max_val_estimate)
208
- lims = [overall_min, overall_max]
209
-
210
- if lims[0] != lims[1]:
211
- fig.add_trace(go.Scatter(
212
- x=lims, y=lims, mode='lines', name='Identity',
213
- line=dict(color='red', width=1),
214
- showlegend=False
215
- ))
216
- fig.update_xaxes(range=lims)
217
- fig.update_yaxes(range=lims, scaleanchor="x", scaleratio=1)
218
-
219
- fig.update_layout(width=fig_width, height=fig_height)
220
- return fig
221
 
222
- except Exception as e:
223
- print(f"Error generating scatter plot: {e}")
224
- fig = go.Figure()
225
- fig.add_annotation(text=f"Plot Error: {e}", xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False)
226
- fig.update_layout(width=fig_width, height=fig_height, title_text=title)
227
- return fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
 
229
 
230
  def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
231
  policy_data_path, pv_base_path, pv_lapse_path, pv_mort_path):
232
- fig_width_summary = 700 # Reduced width
233
- fig_height_summary = 420 # Reduced height
234
  try:
 
235
  cfs = pd.read_excel(cashflow_base_path, index_col=0)
236
  cfs_lapse50 = pd.read_excel(cashflow_lapse_path, index_col=0)
237
  cfs_mort15 = pd.read_excel(cashflow_mort_path, index_col=0)
238
 
239
  pol_data_full = pd.read_excel(policy_data_path, index_col=0)
 
240
  required_cols = ['age_at_entry', 'policy_term', 'sum_assured', 'duration_mth']
241
  if all(col in pol_data_full.columns for col in required_cols):
242
  pol_data = pol_data_full[required_cols]
@@ -252,139 +214,139 @@ def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
252
  scen_titles = ['Base', 'Lapse+50%', 'Mort+15%']
253
 
254
  results = {}
 
255
  mean_attrs = {'age_at_entry':'mean', 'policy_term':'mean', 'duration_mth':'mean', 'sum_assured': 'sum'}
256
 
 
257
  cluster_cfs = Clusters(cfs)
 
258
  results['cf_total_base_table'] = cluster_cfs.compare_total(cfs)
259
  results['cf_policy_attrs_total'] = cluster_cfs.compare_total(pol_data, agg=mean_attrs)
 
260
  results['cf_pv_total_base'] = cluster_cfs.compare_total(pvs)
261
  results['cf_pv_total_lapse'] = cluster_cfs.compare_total(pvs_lapse50)
262
  results['cf_pv_total_mort'] = cluster_cfs.compare_total(pvs_mort15)
 
263
  results['cf_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_cfs, scen_titles)
264
  results['cf_scatter_cashflows_base'] = plot_scatter_comparison(cluster_cfs.compare(cfs), 'Cashflow Calib. - Cashflows (Base)')
265
 
266
- pol_data_numeric = pol_data.apply(pd.to_numeric, errors='coerce')
267
- pol_data_numeric.dropna(axis=1, how='all', inplace=True)
268
-
269
- if not pol_data_numeric.empty:
270
- min_vals = pol_data_numeric.min()
271
- max_vals = pol_data_numeric.max()
272
- range_vals = max_vals - min_vals
273
- loc_vars_attrs = pol_data_numeric.copy()
274
- for col in pol_data_numeric.columns:
275
- if range_vals[col] != 0:
276
- loc_vars_attrs[col] = (pol_data_numeric[col] - min_vals[col]) / range_vals[col]
277
- else:
278
- loc_vars_attrs[col] = 0
279
- loc_vars_attrs = loc_vars_attrs.fillna(0)
280
  else:
281
- gr.Warning("Policy data for attribute calibration is empty or non-numeric. Skipping attribute calibration.")
282
- loc_vars_attrs = pd.DataFrame()
283
-
284
  if not loc_vars_attrs.empty:
285
- try:
286
- cluster_attrs = Clusters(loc_vars_attrs)
287
- results['attr_total_cf_base'] = cluster_attrs.compare_total(cfs)
288
- results['attr_policy_attrs_total'] = cluster_attrs.compare_total(pol_data, agg=mean_attrs)
289
- results['attr_total_pv_base'] = cluster_attrs.compare_total(pvs)
290
- results['attr_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_attrs, scen_titles)
291
- results['attr_scatter_cashflows_base'] = plot_scatter_comparison(cluster_attrs.compare(cfs), 'Policy Attr. Calib. - Cashflows (Base)')
292
- except Exception as e_attr_clust:
293
- gr.Error(f"Error during policy attribute clustering: {e_attr_clust}")
294
- results['attr_total_cf_base'], results['attr_policy_attrs_total'], results['attr_total_pv_base'] = pd.DataFrame(), pd.DataFrame(), pd.DataFrame()
295
- results['attr_cashflow_plot'] = plot_cashflows_comparison([], None, [])
296
- results['attr_scatter_cashflows_base'] = plot_scatter_comparison(pd.DataFrame(), 'Policy Attr. Calib. - Cashflows (Base) Error')
297
  else:
298
- gr.Warning("Skipping attribute calibration as data is empty or non-numeric after processing.")
299
- results['attr_total_cf_base'], results['attr_policy_attrs_total'], results['attr_total_pv_base'] = pd.DataFrame(), pd.DataFrame(), pd.DataFrame()
300
- results['attr_cashflow_plot'] = plot_cashflows_comparison([], None, [])
301
- results['attr_scatter_cashflows_base'] = plot_scatter_comparison(pd.DataFrame(), 'Policy Attr. Calib. - Cashflows (Base) No Data')
302
-
303
 
 
304
  cluster_pvs = Clusters(pvs)
 
305
  results['pv_total_cf_base'] = cluster_pvs.compare_total(cfs)
306
  results['pv_policy_attrs_total'] = cluster_pvs.compare_total(pol_data, agg=mean_attrs)
 
307
  results['pv_total_pv_base'] = cluster_pvs.compare_total(pvs)
308
  results['pv_total_pv_lapse'] = cluster_pvs.compare_total(pvs_lapse50)
309
  results['pv_total_pv_mort'] = cluster_pvs.compare_total(pvs_mort15)
 
310
  results['pv_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_pvs, scen_titles)
311
  results['pv_scatter_pvs_base'] = plot_scatter_comparison(cluster_pvs.compare(pvs), 'PV Calib. - PVs (Base)')
312
 
 
 
 
313
  error_data = {}
 
 
314
  def get_error_safe(compare_result, col_name=None):
315
- if compare_result.empty: return np.nan
316
- if col_name and col_name in compare_result.index: return abs(compare_result.loc[col_name, 'error'])
317
- return abs(compare_result['error']).mean()
 
 
 
 
318
 
319
- key_pv_col = next((pcol for pcol in ['PV_NetCF', 'pv_net_cf', 'net_cf_pv', 'PV_Net_CF'] if pcol in pvs.columns), None)
 
 
 
 
 
320
 
 
321
  error_data['CF Calib.'] = [
322
- get_error_safe(results['cf_pv_total_base'], key_pv_col),
323
- get_error_safe(results['cf_pv_total_lapse'], key_pv_col),
324
- get_error_safe(results['cf_pv_total_mort'], key_pv_col)
325
  ]
326
- if results.get('attr_total_pv_base') is not None and not results['attr_total_pv_base'].empty and 'cluster_attrs' in locals():
 
 
327
  error_data['Attr Calib.'] = [
328
- get_error_safe(cluster_attrs.compare_total(pvs), key_pv_col),
329
  get_error_safe(cluster_attrs.compare_total(pvs_lapse50), key_pv_col),
330
  get_error_safe(cluster_attrs.compare_total(pvs_mort15), key_pv_col)
331
  ]
332
  else:
333
  error_data['Attr Calib.'] = [np.nan, np.nan, np.nan]
334
 
 
335
  error_data['PV Calib.'] = [
336
- get_error_safe(results['pv_total_pv_base'], key_pv_col),
337
- get_error_safe(results['pv_total_pv_lapse'], key_pv_col),
338
- get_error_safe(results['pv_total_pv_mort'], key_pv_col)
339
  ]
340
 
 
341
  summary_df = pd.DataFrame(error_data, index=['Base', 'Lapse+50%', 'Mort+15%'])
 
 
 
 
342
  title_suffix = f' ({key_pv_col})' if key_pv_col else ' (Mean Absolute Error)'
343
- plot_title = f'Calibration Method Comparison - Error in Total PV{title_suffix}'
 
 
 
 
 
 
 
 
 
344
 
345
- summary_df_melted = summary_df.reset_index().melt(id_vars='index', var_name='Calibration Method', value_name='Absolute Error Rate')
346
- summary_df_melted.rename(columns={'index': 'Scenario'}, inplace=True)
347
-
348
- fig_summary = px.bar(
349
- summary_df_melted, x='Scenario', y='Absolute Error Rate',
350
- color='Calibration Method', barmode='group', title=plot_title
351
- )
352
- fig_summary.update_layout(
353
- width=fig_width_summary, height=fig_height_summary,
354
- xaxis_tickangle=0, yaxis_title='Absolute Error Rate', legend_title_text='Calibration Method'
355
- )
356
- fig_summary.update_yaxes(showgrid=True)
357
- results['summary_plot'] = fig_summary
358
-
359
  return results
360
 
361
  except FileNotFoundError as e:
362
- gr.Error(f"File not found: {e.filename}.")
363
  return {"error": f"File not found: {e.filename}"}
364
  except KeyError as e:
365
- gr.Error(f"A required column is missing: {e}.")
366
  return {"error": f"Missing column: {e}"}
367
  except Exception as e:
368
  gr.Error(f"Error processing files: {str(e)}")
369
- import traceback
370
- traceback.print_exc()
371
- error_fig = go.Figure()
372
- error_fig.add_annotation(text=f"Processing Error: {str(e)}", xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False)
373
- error_fig.update_layout(width=fig_width_summary, height=fig_height_summary) # Use some default size
374
- # Return error plots for all expected plot keys
375
- plot_keys = ["summary_plot", "cf_cashflow_plot", "cf_scatter_cashflows_base",
376
- "attr_cashflow_plot", "attr_scatter_cashflows_base",
377
- "pv_cashflow_plot", "pv_scatter_pvs_base"]
378
- error_results = {"error": f"Error processing files: {str(e)}"}
379
- for key in plot_keys:
380
- error_results[key] = error_fig
381
- return error_results
382
 
383
 
384
  def create_interface():
385
  with gr.Blocks(title="Cluster Model Points Analysis") as demo:
386
  gr.Markdown("""
387
  # Cluster Model Points Analysis
 
388
  This application applies cluster analysis to model point selection for insurance portfolios.
389
  Upload your Excel files or use the example data to analyze cashflows, policy attributes, and present values using different calibration methods.
390
 
@@ -396,14 +358,14 @@ def create_interface():
396
  - Present Values - Base Scenario
397
  - Present Values - Lapse Stress
398
  - Present Values - Mortality Stress
399
-
400
- **Note:** Plots are interactive. Hover over data points for details.
401
  """)
402
 
403
  with gr.Row():
404
  with gr.Column(scale=1):
405
  gr.Markdown("### Upload Files or Load Examples")
 
406
  load_example_btn = gr.Button("Load Example Data")
 
407
  with gr.Row():
408
  cashflow_base_input = gr.File(label="Cashflows - Base", file_types=[".xlsx"])
409
  cashflow_lapse_input = gr.File(label="Cashflows - Lapse Stress", file_types=[".xlsx"])
@@ -414,19 +376,20 @@ def create_interface():
414
  pv_lapse_input = gr.File(label="Present Values - Lapse Stress", file_types=[".xlsx"])
415
  with gr.Row():
416
  pv_mort_input = gr.File(label="Present Values - Mortality Stress", file_types=[".xlsx"])
 
417
  analyze_btn = gr.Button("Analyze Dataset", variant="primary", size="lg")
418
 
419
  with gr.Tabs():
420
  with gr.TabItem("📊 Summary"):
421
- summary_plot_output = gr.Plot(label="Calibration Methods Comparison")
422
 
423
  with gr.TabItem("💸 Cashflow Calibration"):
424
  gr.Markdown("### Results: Using Annual Cashflows as Calibration Variables")
425
  with gr.Row():
426
  cf_total_base_table_out = gr.Dataframe(label="Overall Comparison - Base Scenario (Cashflows)")
427
  cf_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes")
428
- cf_cashflow_plot_out = gr.Plot(label="Cashflow Value Comparisons (Actual vs. Estimate) Across Scenarios")
429
- cf_scatter_cashflows_base_out = gr.Plot(label="Scatter Plot - Per-Cluster Cashflows (Base Scenario)")
430
  with gr.Accordion("Present Value Comparisons (Total)", open=False):
431
  with gr.Row():
432
  cf_pv_total_base_out = gr.Dataframe(label="PVs - Base Total")
@@ -438,8 +401,8 @@ def create_interface():
438
  with gr.Row():
439
  attr_total_cf_base_out = gr.Dataframe(label="Overall Comparison - Base Scenario (Cashflows)")
440
  attr_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes")
441
- attr_cashflow_plot_out = gr.Plot(label="Cashflow Value Comparisons (Actual vs. Estimate) Across Scenarios")
442
- attr_scatter_cashflows_base_out = gr.Plot(label="Scatter Plot - Per-Cluster Cashflows (Base Scenario)")
443
  with gr.Accordion("Present Value Comparisons (Total)", open=False):
444
  attr_total_pv_base_out = gr.Dataframe(label="PVs - Base Scenario Total")
445
 
@@ -448,36 +411,45 @@ def create_interface():
448
  with gr.Row():
449
  pv_total_cf_base_out = gr.Dataframe(label="Overall Comparison - Base Scenario (Cashflows)")
450
  pv_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes")
451
- pv_cashflow_plot_out = gr.Plot(label="Cashflow Value Comparisons (Actual vs. Estimate) Across Scenarios")
452
- pv_scatter_pvs_base_out = gr.Plot(label="Scatter Plot - Per-Cluster Present Values (Base Scenario)")
453
  with gr.Accordion("Present Value Comparisons (Total)", open=False):
454
  with gr.Row():
455
  pv_total_pv_base_out = gr.Dataframe(label="PVs - Base Total")
456
  pv_total_pv_lapse_out = gr.Dataframe(label="PVs - Lapse Stress Total")
457
  pv_total_pv_mort_out = gr.Dataframe(label="PVs - Mortality Stress Total")
458
 
 
459
  def get_all_output_components():
460
  return [
461
  summary_plot_output,
 
462
  cf_total_base_table_out, cf_policy_attrs_total_out,
463
  cf_cashflow_plot_out, cf_scatter_cashflows_base_out,
464
  cf_pv_total_base_out, cf_pv_total_lapse_out, cf_pv_total_mort_out,
 
465
  attr_total_cf_base_out, attr_policy_attrs_total_out,
466
  attr_cashflow_plot_out, attr_scatter_cashflows_base_out, attr_total_pv_base_out,
 
467
  pv_total_cf_base_out, pv_policy_attrs_total_out,
468
  pv_cashflow_plot_out, pv_scatter_pvs_base_out,
469
  pv_total_pv_base_out, pv_total_pv_lapse_out, pv_total_pv_mort_out
470
  ]
471
 
 
472
  def handle_analysis(f1, f2, f3, f4, f5, f6, f7):
473
  files = [f1, f2, f3, f4, f5, f6, f7]
 
474
  file_paths = []
475
  for i, f_obj in enumerate(files):
476
  if f_obj is None:
477
  gr.Error(f"Missing file input for argument {i+1}. Please upload all files or load examples.")
478
- return [None] * len(get_all_output_components())
 
 
479
  if hasattr(f_obj, 'name') and isinstance(f_obj.name, str):
480
  file_paths.append(f_obj.name)
 
481
  elif isinstance(f_obj, str):
482
  file_paths.append(f_obj)
483
  else:
@@ -486,37 +458,23 @@ def create_interface():
486
 
487
  results = process_files(*file_paths)
488
 
489
- default_error_plot = go.Figure()
490
- default_error_plot.add_annotation(text="Analysis Error or No Data.", xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False)
491
- default_error_plot.update_layout(width=600, height=400) # Generic small size for error plot
492
-
493
- output_list = [
494
- results.get('summary_plot', default_error_plot),
495
  results.get('cf_total_base_table'), results.get('cf_policy_attrs_total'),
496
- results.get('cf_cashflow_plot', default_error_plot),
497
- results.get('cf_scatter_cashflows_base', default_error_plot),
498
  results.get('cf_pv_total_base'), results.get('cf_pv_total_lapse'), results.get('cf_pv_total_mort'),
499
-
500
  results.get('attr_total_cf_base'), results.get('attr_policy_attrs_total'),
501
- results.get('attr_cashflow_plot', default_error_plot),
502
- results.get('attr_scatter_cashflows_base', default_error_plot),
503
- results.get('attr_total_pv_base'),
504
-
505
  results.get('pv_total_cf_base'), results.get('pv_policy_attrs_total'),
506
- results.get('pv_cashflow_plot', default_error_plot),
507
- results.get('pv_scatter_pvs_base', default_error_plot),
508
  results.get('pv_total_pv_base'), results.get('pv_total_pv_lapse'), results.get('pv_total_pv_mort')
509
  ]
510
- # Ensure dataframes are None if not found, not error plots
511
- df_indices = [1, 2, 5, 6, 7, 8, 9, 12, 13, 14, 17,18,19] # Indices of dataframe outputs
512
- for idx in df_indices:
513
- if not isinstance(output_list[idx], pd.DataFrame) and output_list[idx] is not None :
514
- if results.get("error") and output_list[idx] is default_error_plot: # if it's an error plot because of main error
515
- output_list[idx] = None # set df to None
516
- elif not results.get("error"): # if no main error, but specific df missing
517
- output_list[idx] = pd.DataFrame() # set to empty df
518
- return output_list
519
-
520
 
521
  analyze_btn.click(
522
  handle_analysis,
@@ -525,11 +483,12 @@ def create_interface():
525
  outputs=get_all_output_components()
526
  )
527
 
 
528
  def load_example_files():
529
  missing_files = [fp for fp in EXAMPLE_FILES.values() if not os.path.exists(fp)]
530
  if missing_files:
531
  gr.Error(f"Missing example data files in '{EXAMPLE_DATA_DIR}': {', '.join(missing_files)}. Please ensure they exist.")
532
- return [None] * 7
533
 
534
  gr.Info("Example data paths loaded. Click 'Analyze Dataset'.")
535
  return [
 
3
  import pandas as pd
4
  from sklearn.cluster import KMeans
5
  from sklearn.metrics import pairwise_distances_argmin_min, r2_score
6
+ import matplotlib.pyplot as plt
7
+ import matplotlib.cm
8
+ import io
9
+ import os # Added for path joining
10
+ from PIL import Image
11
 
12
  # Define the paths for example data
13
  EXAMPLE_DATA_DIR = "eg_data"
 
26
  self.kmeans = kmeans = KMeans(n_clusters=1000, random_state=0, n_init=10).fit(np.ascontiguousarray(loc_vars))
27
  closest, _ = pairwise_distances_argmin_min(kmeans.cluster_centers_, np.ascontiguousarray(loc_vars))
28
 
29
+ rep_ids = pd.Series(data=(closest+1)) # 0-based to 1-based indexes
30
  rep_ids.name = 'policy_id'
31
  rep_ids.index.name = 'cluster_id'
32
  self.rep_ids = rep_ids
 
34
  self.policy_count = self.agg_by_cluster(pd.DataFrame({'policy_count': [1] * len(loc_vars)}))['policy_count']
35
 
36
  def agg_by_cluster(self, df, agg=None):
37
+ """Aggregate columns by cluster"""
38
  temp = df.copy()
39
  temp['cluster_id'] = self.kmeans.labels_
40
  temp = temp.set_index('cluster_id')
 
42
  return temp.groupby(temp.index).agg(agg)
43
 
44
  def extract_reps(self, df):
45
+ """Extract the rows of representative policies"""
46
  temp = pd.merge(self.rep_ids, df.reset_index(), how='left', on='policy_id')
47
  temp.index.name = 'cluster_id'
48
  return temp.drop('policy_id', axis=1)
49
 
50
  def extract_and_scale_reps(self, df, agg=None):
51
+ """Extract and scale the rows of representative policies"""
52
  if agg:
53
  cols = df.columns
54
  mult = pd.DataFrame({c: (self.policy_count if (c not in agg or agg[c] == 'sum') else 1) for c in cols})
55
+ # Ensure mult has same index as extract_reps(df) for proper alignment
56
  extracted_df = self.extract_reps(df)
57
  mult.index = extracted_df.index
58
  return extracted_df.mul(mult)
 
60
  return self.extract_reps(df).mul(self.policy_count, axis=0)
61
 
62
  def compare(self, df, agg=None):
63
+ """Returns a multi-indexed Dataframe comparing actual and estimate"""
64
  source = self.agg_by_cluster(df, agg)
65
  target = self.extract_and_scale_reps(df, agg)
66
  return pd.DataFrame({'actual': source.stack(), 'estimate':target.stack()})
67
 
68
  def compare_total(self, df, agg=None):
69
+ """Aggregate df by columns"""
70
  if agg:
71
+ # Calculate actual values using specified aggregation
72
  actual_values = {}
73
  for col in df.columns:
74
  if agg.get(col, 'sum') == 'mean':
75
  actual_values[col] = df[col].mean()
76
+ else: # sum
77
  actual_values[col] = df[col].sum()
78
  actual = pd.Series(actual_values)
79
 
80
+ # Calculate estimate values
81
  reps_unscaled = self.extract_reps(df)
82
  estimate_values = {}
83
 
84
  for col in df.columns:
85
  if agg.get(col, 'sum') == 'mean':
86
+ # Weighted average for mean columns
87
  weighted_sum = (reps_unscaled[col] * self.policy_count).sum()
88
  total_weight = self.policy_count.sum()
89
  estimate_values[col] = weighted_sum / total_weight if total_weight > 0 else 0
90
+ else: # sum
91
  estimate_values[col] = (reps_unscaled[col] * self.policy_count).sum()
92
+
93
  estimate = pd.Series(estimate_values)
94
+
95
+ else: # Original logic if no agg is specified (all sum)
96
  actual = df.sum()
97
  estimate = self.extract_and_scale_reps(df).sum()
98
 
99
+ # Calculate error, handling division by zero
100
  error = np.where(actual != 0, estimate / actual - 1, 0)
101
+
102
  return pd.DataFrame({'actual': actual, 'estimate': estimate, 'error': error})
103
 
104
 
105
  def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
106
+ """Create cashflow comparison plots"""
 
 
107
  if not cfs_list or not cluster_obj or not titles:
108
+ return None
 
 
 
 
109
  num_plots = len(cfs_list)
110
  if num_plots == 0:
111
+ return None
 
 
 
112
 
113
+ # Determine subplot layout
114
  cols = 2
115
  rows = (num_plots + cols - 1) // cols
 
116
 
117
+ fig, axes = plt.subplots(rows, cols, figsize=(15, 5 * rows), squeeze=False)
118
+ axes = axes.flatten()
119
+
120
+ for i, (df, title) in enumerate(zip(cfs_list, titles)):
121
+ if i < len(axes):
122
+ comparison = cluster_obj.compare_total(df)
123
+ comparison[['actual', 'estimate']].plot(ax=axes[i], grid=True, title=title)
124
+ axes[i].set_xlabel('Time')
125
+ axes[i].set_ylabel('Value')
126
+
127
+ # Hide any unused subplots
128
+ for j in range(i + 1, len(axes)):
129
+ fig.delaxes(axes[j])
 
 
 
 
 
 
 
 
 
 
130
 
131
+ plt.tight_layout()
132
+ buf = io.BytesIO()
133
+ plt.savefig(buf, format='png', dpi=100)
134
+ buf.seek(0)
135
+ img = Image.open(buf)
136
+ plt.close(fig)
137
+ return img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
139
  def plot_scatter_comparison(df_compare_output, title):
140
+ """Create scatter plot comparison from compare() output"""
141
+ if df_compare_output is None or df_compare_output.empty:
142
+ # Create a blank plot with a message
143
+ fig, ax = plt.subplots(figsize=(12, 8))
144
+ ax.text(0.5, 0.5, "No data to display", ha='center', va='center', fontsize=15)
145
+ ax.set_title(title)
146
+ buf = io.BytesIO()
147
+ plt.savefig(buf, format='png', dpi=100)
148
+ buf.seek(0)
149
+ img = Image.open(buf)
150
+ plt.close(fig)
151
+ return img
152
+
153
+ fig, ax = plt.subplots(figsize=(12, 8))
154
+
155
+ if not isinstance(df_compare_output.index, pd.MultiIndex) or df_compare_output.index.nlevels < 2:
156
+ gr.Warning("Scatter plot data is not in the expected multi-index format. Plotting raw actual vs estimate.")
157
+ ax.scatter(df_compare_output['actual'], df_compare_output['estimate'], s=9, alpha=0.6)
158
+ else:
159
+ unique_levels = df_compare_output.index.get_level_values(1).unique()
160
+ colors = matplotlib.cm.rainbow(np.linspace(0, 1, len(unique_levels)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
 
162
+ for item_level, color_val in zip(unique_levels, colors):
163
+ subset = df_compare_output.xs(item_level, level=1)
164
+ ax.scatter(subset['actual'], subset['estimate'], color=color_val, s=9, alpha=0.6, label=item_level)
165
+ if len(unique_levels) > 1 and len(unique_levels) <= 10:
166
+ ax.legend(title=df_compare_output.index.names[1])
167
+
168
+ ax.set_xlabel('Actual')
169
+ ax.set_ylabel('Estimate')
170
+ ax.set_title(title)
171
+ ax.grid(True)
172
+
173
+ # Draw identity line
174
+ lims = [
175
+ np.min([ax.get_xlim(), ax.get_ylim()]),
176
+ np.max([ax.get_xlim(), ax.get_ylim()]),
177
+ ]
178
+ if lims[0] != lims[1]:
179
+ ax.plot(lims, lims, 'r-', linewidth=0.5)
180
+ ax.set_xlim(lims)
181
+ ax.set_ylim(lims)
182
+
183
+ buf = io.BytesIO()
184
+ plt.savefig(buf, format='png', dpi=100)
185
+ buf.seek(0)
186
+ img = Image.open(buf)
187
+ plt.close(fig)
188
+ return img
189
 
190
 
191
  def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
192
  policy_data_path, pv_base_path, pv_lapse_path, pv_mort_path):
193
+ """Main processing function - now accepts file paths"""
 
194
  try:
195
+ # Read uploaded files using paths
196
  cfs = pd.read_excel(cashflow_base_path, index_col=0)
197
  cfs_lapse50 = pd.read_excel(cashflow_lapse_path, index_col=0)
198
  cfs_mort15 = pd.read_excel(cashflow_mort_path, index_col=0)
199
 
200
  pol_data_full = pd.read_excel(policy_data_path, index_col=0)
201
+ # Ensure the correct columns are selected for pol_data
202
  required_cols = ['age_at_entry', 'policy_term', 'sum_assured', 'duration_mth']
203
  if all(col in pol_data_full.columns for col in required_cols):
204
  pol_data = pol_data_full[required_cols]
 
214
  scen_titles = ['Base', 'Lapse+50%', 'Mort+15%']
215
 
216
  results = {}
217
+
218
  mean_attrs = {'age_at_entry':'mean', 'policy_term':'mean', 'duration_mth':'mean', 'sum_assured': 'sum'}
219
 
220
+ # --- 1. Cashflow Calibration ---
221
  cluster_cfs = Clusters(cfs)
222
+
223
  results['cf_total_base_table'] = cluster_cfs.compare_total(cfs)
224
  results['cf_policy_attrs_total'] = cluster_cfs.compare_total(pol_data, agg=mean_attrs)
225
+
226
  results['cf_pv_total_base'] = cluster_cfs.compare_total(pvs)
227
  results['cf_pv_total_lapse'] = cluster_cfs.compare_total(pvs_lapse50)
228
  results['cf_pv_total_mort'] = cluster_cfs.compare_total(pvs_mort15)
229
+
230
  results['cf_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_cfs, scen_titles)
231
  results['cf_scatter_cashflows_base'] = plot_scatter_comparison(cluster_cfs.compare(cfs), 'Cashflow Calib. - Cashflows (Base)')
232
 
233
+ # --- 2. Policy Attribute Calibration ---
234
+ # Standardize policy attributes
235
+ if not pol_data.empty and (pol_data.max() - pol_data.min()).all() != 0:
236
+ loc_vars_attrs = (pol_data - pol_data.min()) / (pol_data.max() - pol_data.min())
 
 
 
 
 
 
 
 
 
 
237
  else:
238
+ gr.Warning("Policy data for attribute calibration is empty or has no variance. Skipping attribute calibration plots.")
239
+ loc_vars_attrs = pol_data
240
+
241
  if not loc_vars_attrs.empty:
242
+ cluster_attrs = Clusters(loc_vars_attrs)
243
+ results['attr_total_cf_base'] = cluster_attrs.compare_total(cfs)
244
+ results['attr_policy_attrs_total'] = cluster_attrs.compare_total(pol_data, agg=mean_attrs)
245
+ results['attr_total_pv_base'] = cluster_attrs.compare_total(pvs)
246
+ results['attr_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_attrs, scen_titles)
247
+ results['attr_scatter_cashflows_base'] = plot_scatter_comparison(cluster_attrs.compare(cfs), 'Policy Attr. Calib. - Cashflows (Base)')
 
 
 
 
 
 
248
  else:
249
+ results['attr_total_cf_base'] = pd.DataFrame()
250
+ results['attr_policy_attrs_total'] = pd.DataFrame()
251
+ results['attr_total_pv_base'] = pd.DataFrame()
252
+ results['attr_cashflow_plot'] = None
253
+ results['attr_scatter_cashflows_base'] = None
254
 
255
+ # --- 3. Present Value Calibration ---
256
  cluster_pvs = Clusters(pvs)
257
+
258
  results['pv_total_cf_base'] = cluster_pvs.compare_total(cfs)
259
  results['pv_policy_attrs_total'] = cluster_pvs.compare_total(pol_data, agg=mean_attrs)
260
+
261
  results['pv_total_pv_base'] = cluster_pvs.compare_total(pvs)
262
  results['pv_total_pv_lapse'] = cluster_pvs.compare_total(pvs_lapse50)
263
  results['pv_total_pv_mort'] = cluster_pvs.compare_total(pvs_mort15)
264
+
265
  results['pv_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_pvs, scen_titles)
266
  results['pv_scatter_pvs_base'] = plot_scatter_comparison(cluster_pvs.compare(pvs), 'PV Calib. - PVs (Base)')
267
 
268
+ # --- Summary Comparison Plot Data ---
269
+ # Error metric for key PV column or mean absolute error
270
+
271
  error_data = {}
272
+
273
# Safely pull an error metric out of a compare_total() result.
def get_error_safe(compare_result, col_name=None):
    """Return an absolute error from a comparison table.

    If *col_name* is given and present in the table's index, return the
    absolute error of that single row; otherwise fall back to the mean
    absolute error across all rows. An empty table yields NaN.
    """
    if compare_result.empty:
        return np.nan
    if col_name and col_name in compare_result.index:
        return abs(compare_result.loc[col_name, 'error'])
    # Specific column unavailable: average the absolute errors instead.
    return compare_result['error'].abs().mean()
282
 
283
+ # Determine key PV column (try common names)
284
+ key_pv_col = None
285
+ for potential_col in ['PV_NetCF', 'pv_net_cf', 'net_cf_pv', 'PV_Net_CF']:
286
+ if potential_col in pvs.columns:
287
+ key_pv_col = potential_col
288
+ break
289
 
290
+ # Cashflow Calibration Errors
291
  error_data['CF Calib.'] = [
292
+ get_error_safe(cluster_cfs.compare_total(pvs), key_pv_col),
293
+ get_error_safe(cluster_cfs.compare_total(pvs_lapse50), key_pv_col),
294
+ get_error_safe(cluster_cfs.compare_total(pvs_mort15), key_pv_col)
295
  ]
296
+
297
+ # Policy Attribute Calibration Errors
298
+ if not loc_vars_attrs.empty:
299
  error_data['Attr Calib.'] = [
300
+ get_error_safe(cluster_attrs.compare_total(pvs), key_pv_col),
301
  get_error_safe(cluster_attrs.compare_total(pvs_lapse50), key_pv_col),
302
  get_error_safe(cluster_attrs.compare_total(pvs_mort15), key_pv_col)
303
  ]
304
  else:
305
  error_data['Attr Calib.'] = [np.nan, np.nan, np.nan]
306
 
307
+ # Present Value Calibration Errors
308
  error_data['PV Calib.'] = [
309
+ get_error_safe(cluster_pvs.compare_total(pvs), key_pv_col),
310
+ get_error_safe(cluster_pvs.compare_total(pvs_lapse50), key_pv_col),
311
+ get_error_safe(cluster_pvs.compare_total(pvs_mort15), key_pv_col)
312
  ]
313
 
314
+ # Create Summary Plot
315
  summary_df = pd.DataFrame(error_data, index=['Base', 'Lapse+50%', 'Mort+15%'])
316
+
317
+ fig_summary, ax_summary = plt.subplots(figsize=(10, 6))
318
+ summary_df.plot(kind='bar', ax=ax_summary, grid=True)
319
+ ax_summary.set_ylabel('Absolute Error Rate')
320
  title_suffix = f' ({key_pv_col})' if key_pv_col else ' (Mean Absolute Error)'
321
+ ax_summary.set_title(f'Calibration Method Comparison - Error in Total PV{title_suffix}')
322
+ ax_summary.tick_params(axis='x', rotation=0)
323
+ ax_summary.legend(title='Calibration Method')
324
+ plt.tight_layout()
325
+
326
+ buf_summary = io.BytesIO()
327
+ plt.savefig(buf_summary, format='png', dpi=100)
328
+ buf_summary.seek(0)
329
+ results['summary_plot'] = Image.open(buf_summary)
330
+ plt.close(fig_summary)
331
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
332
  return results
333
 
334
  except FileNotFoundError as e:
335
+ gr.Error(f"File not found: {e.filename}. Please ensure example files are in '{EXAMPLE_DATA_DIR}' or all files are uploaded.")
336
  return {"error": f"File not found: {e.filename}"}
337
  except KeyError as e:
338
+ gr.Error(f"A required column is missing from one of the excel files: {e}. Please check data format.")
339
  return {"error": f"Missing column: {e}"}
340
  except Exception as e:
341
  gr.Error(f"Error processing files: {str(e)}")
342
+ return {"error": f"Error processing files: {str(e)}"}
 
 
 
 
 
 
 
 
 
 
 
 
343
 
344
 
345
  def create_interface():
346
  with gr.Blocks(title="Cluster Model Points Analysis") as demo:
347
  gr.Markdown("""
348
  # Cluster Model Points Analysis
349
+
350
  This application applies cluster analysis to model point selection for insurance portfolios.
351
  Upload your Excel files or use the example data to analyze cashflows, policy attributes, and present values using different calibration methods.
352
 
 
358
  - Present Values - Base Scenario
359
  - Present Values - Lapse Stress
360
  - Present Values - Mortality Stress
 
 
361
  """)
362
 
363
  with gr.Row():
364
  with gr.Column(scale=1):
365
  gr.Markdown("### Upload Files or Load Examples")
366
+
367
  load_example_btn = gr.Button("Load Example Data")
368
+
369
  with gr.Row():
370
  cashflow_base_input = gr.File(label="Cashflows - Base", file_types=[".xlsx"])
371
  cashflow_lapse_input = gr.File(label="Cashflows - Lapse Stress", file_types=[".xlsx"])
 
376
  pv_lapse_input = gr.File(label="Present Values - Lapse Stress", file_types=[".xlsx"])
377
  with gr.Row():
378
  pv_mort_input = gr.File(label="Present Values - Mortality Stress", file_types=[".xlsx"])
379
+
380
  analyze_btn = gr.Button("Analyze Dataset", variant="primary", size="lg")
381
 
382
  with gr.Tabs():
383
  with gr.TabItem("📊 Summary"):
384
+ summary_plot_output = gr.Image(label="Calibration Methods Comparison")
385
 
386
  with gr.TabItem("💸 Cashflow Calibration"):
387
  gr.Markdown("### Results: Using Annual Cashflows as Calibration Variables")
388
  with gr.Row():
389
  cf_total_base_table_out = gr.Dataframe(label="Overall Comparison - Base Scenario (Cashflows)")
390
  cf_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes")
391
+ cf_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate) Across Scenarios")
392
+ cf_scatter_cashflows_base_out = gr.Image(label="Scatter Plot - Per-Cluster Cashflows (Base Scenario)")
393
  with gr.Accordion("Present Value Comparisons (Total)", open=False):
394
  with gr.Row():
395
  cf_pv_total_base_out = gr.Dataframe(label="PVs - Base Total")
 
401
  with gr.Row():
402
  attr_total_cf_base_out = gr.Dataframe(label="Overall Comparison - Base Scenario (Cashflows)")
403
  attr_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes")
404
+ attr_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate) Across Scenarios")
405
+ attr_scatter_cashflows_base_out = gr.Image(label="Scatter Plot - Per-Cluster Cashflows (Base Scenario)")
406
  with gr.Accordion("Present Value Comparisons (Total)", open=False):
407
  attr_total_pv_base_out = gr.Dataframe(label="PVs - Base Scenario Total")
408
 
 
411
  with gr.Row():
412
  pv_total_cf_base_out = gr.Dataframe(label="Overall Comparison - Base Scenario (Cashflows)")
413
  pv_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes")
414
+ pv_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate) Across Scenarios")
415
+ pv_scatter_pvs_base_out = gr.Image(label="Scatter Plot - Per-Cluster Present Values (Base Scenario)")
416
  with gr.Accordion("Present Value Comparisons (Total)", open=False):
417
  with gr.Row():
418
  pv_total_pv_base_out = gr.Dataframe(label="PVs - Base Total")
419
  pv_total_pv_lapse_out = gr.Dataframe(label="PVs - Lapse Stress Total")
420
  pv_total_pv_mort_out = gr.Dataframe(label="PVs - Mortality Stress Total")
421
 
422
# --- Helper: collect every UI output component in a fixed order ---
def get_all_output_components():
    """Return all output widgets in the exact order produced by the
    analysis handler's return list (summary first, then the three
    calibration sections)."""
    cashflow_calib = [
        cf_total_base_table_out, cf_policy_attrs_total_out,
        cf_cashflow_plot_out, cf_scatter_cashflows_base_out,
        cf_pv_total_base_out, cf_pv_total_lapse_out, cf_pv_total_mort_out,
    ]
    attr_calib = [
        attr_total_cf_base_out, attr_policy_attrs_total_out,
        attr_cashflow_plot_out, attr_scatter_cashflows_base_out,
        attr_total_pv_base_out,
    ]
    pv_calib = [
        pv_total_cf_base_out, pv_policy_attrs_total_out,
        pv_cashflow_plot_out, pv_scatter_pvs_base_out,
        pv_total_pv_base_out, pv_total_pv_lapse_out, pv_total_pv_mort_out,
    ]
    return [summary_plot_output] + cashflow_calib + attr_calib + pv_calib
438
 
439
+ # --- Action for Analyze Button ---
440
  def handle_analysis(f1, f2, f3, f4, f5, f6, f7):
441
  files = [f1, f2, f3, f4, f5, f6, f7]
442
+
443
  file_paths = []
444
  for i, f_obj in enumerate(files):
445
  if f_obj is None:
446
  gr.Error(f"Missing file input for argument {i+1}. Please upload all files or load examples.")
447
+ return [None] * len(get_all_output_components())
448
+
449
+ # If f_obj is a Gradio FileData object (from direct upload)
450
  if hasattr(f_obj, 'name') and isinstance(f_obj.name, str):
451
  file_paths.append(f_obj.name)
452
+ # If f_obj is already a string path (from example load)
453
  elif isinstance(f_obj, str):
454
  file_paths.append(f_obj)
455
  else:
 
458
 
459
  results = process_files(*file_paths)
460
 
461
+ if "error" in results:
462
+ return [None] * len(get_all_output_components())
463
+
464
+ return [
465
+ results.get('summary_plot'),
466
+ # CF Calib
467
  results.get('cf_total_base_table'), results.get('cf_policy_attrs_total'),
468
+ results.get('cf_cashflow_plot'), results.get('cf_scatter_cashflows_base'),
 
469
  results.get('cf_pv_total_base'), results.get('cf_pv_total_lapse'), results.get('cf_pv_total_mort'),
470
+ # Attr Calib
471
  results.get('attr_total_cf_base'), results.get('attr_policy_attrs_total'),
472
+ results.get('attr_cashflow_plot'), results.get('attr_scatter_cashflows_base'), results.get('attr_total_pv_base'),
473
+ # PV Calib
 
 
474
  results.get('pv_total_cf_base'), results.get('pv_policy_attrs_total'),
475
+ results.get('pv_cashflow_plot'), results.get('pv_scatter_pvs_base'),
 
476
  results.get('pv_total_pv_base'), results.get('pv_total_pv_lapse'), results.get('pv_total_pv_mort')
477
  ]
 
 
 
 
 
 
 
 
 
 
478
 
479
  analyze_btn.click(
480
  handle_analysis,
 
483
  outputs=get_all_output_components()
484
  )
485
 
486
+ # --- Action for Load Example Data Button ---
487
  def load_example_files():
488
  missing_files = [fp for fp in EXAMPLE_FILES.values() if not os.path.exists(fp)]
489
  if missing_files:
490
  gr.Error(f"Missing example data files in '{EXAMPLE_DATA_DIR}': {', '.join(missing_files)}. Please ensure they exist.")
491
+ return [None] * 7
492
 
493
  gr.Info("Example data paths loaded. Click 'Analyze Dataset'.")
494
  return [