alidenewade commited on
Commit
9d69ed1
·
verified ·
1 Parent(s): 30f0b16

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +428 -216
app.py CHANGED
@@ -6,15 +6,27 @@ from sklearn.metrics import pairwise_distances_argmin_min, r2_score
6
  import matplotlib.pyplot as plt
7
  import matplotlib.cm
8
  import io
9
- import base64
10
  from PIL import Image
11
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  class Clusters:
13
  def __init__(self, loc_vars):
14
  self.kmeans = kmeans = KMeans(n_clusters=1000, random_state=0, n_init=10).fit(np.ascontiguousarray(loc_vars))
15
  closest, _ = pairwise_distances_argmin_min(kmeans.cluster_centers_, np.ascontiguousarray(loc_vars))
16
 
17
- rep_ids = pd.Series(data=(closest+1)) # 0-based to 1-based indexes
18
  rep_ids.name = 'policy_id'
19
  rep_ids.index.name = 'cluster_id'
20
  self.rep_ids = rep_ids
@@ -26,7 +38,7 @@ class Clusters:
26
  temp = df.copy()
27
  temp['cluster_id'] = self.kmeans.labels_
28
  temp = temp.set_index('cluster_id')
29
- agg = {c: (agg[c] if c in agg else 'sum') for c in temp.columns} if agg else "sum"
30
  return temp.groupby(temp.index).agg(agg)
31
 
32
  def extract_reps(self, df):
@@ -40,7 +52,10 @@ class Clusters:
40
  if agg:
41
  cols = df.columns
42
  mult = pd.DataFrame({c: (self.policy_count if (c not in agg or agg[c] == 'sum') else 1) for c in cols})
43
- return self.extract_reps(df).mul(mult)
 
 
 
44
  else:
45
  return self.extract_reps(df).mul(self.policy_count, axis=0)
46
 
@@ -53,307 +68,504 @@ class Clusters:
53
  def compare_total(self, df, agg=None):
54
  """Aggregate df by columns"""
55
  if agg:
56
- cols = df.columns
57
  op = {c: (agg[c] if c in agg else 'sum') for c in df.columns}
58
  actual = df.agg(op)
59
- estimate = self.extract_and_scale_reps(df, agg=op)
60
 
61
- op = {k: ((lambda s: s.dot(self.policy_count) / self.policy_count.sum()) if v == 'mean' else v) for k, v in op.items()}
62
- estimate = estimate.agg(op)
63
- else:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  actual = df.sum()
65
  estimate = self.extract_and_scale_reps(df).sum()
66
 
67
  return pd.DataFrame({'actual': actual, 'estimate': estimate, 'error': estimate / actual - 1})
68
 
69
- def create_plot(plot_func, *args, **kwargs):
70
- """Helper function to create plots and return as image"""
71
- plt.figure(figsize=(10, 6))
72
- plot_func(*args, **kwargs)
73
-
74
- # Save plot to bytes
75
- buf = io.BytesIO()
76
- plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
77
- buf.seek(0)
78
- plt.close()
79
-
80
- return Image.open(buf)
81
 
82
  def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
83
  """Create cashflow comparison plots"""
84
- fig, axes = plt.subplots(2, 2, figsize=(15, 10))
 
 
 
 
 
 
 
 
 
 
85
  axes = axes.flatten()
86
 
87
  for i, (df, title) in enumerate(zip(cfs_list, titles)):
88
  if i < len(axes):
89
  comparison = cluster_obj.compare_total(df)
90
  comparison[['actual', 'estimate']].plot(ax=axes[i], grid=True, title=title)
 
 
91
 
 
 
 
 
92
  plt.tight_layout()
93
  buf = io.BytesIO()
94
- plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
95
  buf.seek(0)
96
- plt.close()
97
-
98
- return Image.open(buf)
99
 
100
- def plot_scatter_comparison(df, title):
101
- """Create scatter plot comparison"""
102
- plt.figure(figsize=(12, 8))
103
-
104
- colors = matplotlib.cm.rainbow(np.linspace(0, 1, len(df.index.levels[1])))
105
-
106
- for y, c in zip(df.index.levels[1], colors):
107
- plt.scatter(df.xs(y, level=1)['actual'], df.xs(y, level=1)['estimate'],
108
- color=c, s=9, alpha=0.6)
 
 
 
 
 
 
109
 
110
- plt.xlabel('Actual')
111
- plt.ylabel('Estimate')
112
- plt.title(title)
113
- plt.grid(True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
  # Draw identity line
116
  lims = [
117
- np.min([plt.xlim(), plt.ylim()]),
118
- np.max([plt.xlim(), plt.ylim()]),
119
  ]
120
- plt.plot(lims, lims, 'r-', linewidth=0.5)
121
- plt.xlim(lims)
122
- plt.ylim(lims)
 
123
 
124
  buf = io.BytesIO()
125
- plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
126
  buf.seek(0)
127
- plt.close()
128
-
129
- return Image.open(buf)
 
130
 
131
- def process_files(cashflow_base, cashflow_lapse, cashflow_mort, policy_data, pv_base, pv_lapse, pv_mort):
132
- """Main processing function"""
 
133
  try:
134
- # Read uploaded files
135
- cfs = pd.read_excel(cashflow_base.name, index_col=0)
136
- cfs_lapse50 = pd.read_excel(cashflow_lapse.name, index_col=0)
137
- cfs_mort15 = pd.read_excel(cashflow_mort.name, index_col=0)
138
 
139
- pol_data = pd.read_excel(policy_data.name, index_col=0)
140
- if pol_data.shape[1] > 4:
141
- pol_data = pol_data[['age_at_entry', 'policy_term', 'sum_assured', 'duration_mth']]
142
-
143
- pvs = pd.read_excel(pv_base.name, index_col=0)
144
- pvs_lapse50 = pd.read_excel(pv_lapse.name, index_col=0)
145
- pvs_mort15 = pd.read_excel(pv_mort.name, index_col=0)
 
 
 
 
 
 
 
146
 
147
  cfs_list = [cfs, cfs_lapse50, cfs_mort15]
148
- pvs_list = [pvs, pvs_lapse50, pvs_mort15]
149
  scen_titles = ['Base', 'Lapse+50%', 'Mort+15%']
150
 
151
  results = {}
152
 
153
- # 1. Cashflow Calibration
 
 
154
  cluster_cfs = Clusters(cfs)
155
 
156
- # Cashflow comparison tables
157
- results['cf_base_table'] = cluster_cfs.compare_total(cfs)
158
- results['cf_lapse_table'] = cluster_cfs.compare_total(cfs_lapse50)
159
- results['cf_mort_table'] = cluster_cfs.compare_total(cfs_mort15)
160
-
161
- # Policy attributes analysis
162
- mean_attrs = {'age_at_entry':'mean', 'policy_term':'mean', 'duration_mth':'mean'}
163
- results['cf_policy_attrs'] = cluster_cfs.compare_total(pol_data, agg=mean_attrs)
164
 
165
- # Present value analysis
166
- results['cf_pv_base'] = cluster_cfs.compare_total(pvs)
167
- results['cf_pv_lapse'] = cluster_cfs.compare_total(pvs_lapse50)
168
- results['cf_pv_mort'] = cluster_cfs.compare_total(pvs_mort15)
169
 
170
- # Create plots for cashflow calibration
171
  results['cf_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_cfs, scen_titles)
172
- results['cf_scatter_base'] = plot_scatter_comparison(cluster_cfs.compare(cfs), 'Cashflow Calibration - Base Scenario')
 
 
 
 
 
 
 
 
 
 
173
 
174
- # 2. Policy Attribute Calibration
175
- loc_vars = (pol_data - pol_data.min()) / (pol_data.max() - pol_data.min())
176
- cluster_attrs = Clusters(loc_vars)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
- results['attr_cf_base'] = cluster_attrs.compare_total(cfs)
179
- results['attr_policy_attrs'] = cluster_attrs.compare_total(pol_data, agg=mean_attrs)
180
- results['attr_pv_base'] = cluster_attrs.compare_total(pvs)
181
- results['attr_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_attrs, scen_titles)
182
- results['attr_scatter_base'] = plot_scatter_comparison(cluster_attrs.compare(cfs), 'Policy Attribute Calibration - Base Scenario')
183
 
184
- # 3. Present Value Calibration
185
- cluster_pvs = Clusters(pvs)
 
186
 
187
- results['pv_cf_base'] = cluster_pvs.compare_total(cfs)
188
- results['pv_policy_attrs'] = cluster_pvs.compare_total(pol_data, agg=mean_attrs)
189
- results['pv_pv_base'] = cluster_pvs.compare_total(pvs)
190
- results['pv_pv_lapse'] = cluster_pvs.compare_total(pvs_lapse50)
191
- results['pv_pv_mort'] = cluster_pvs.compare_total(pvs_mort15)
192
  results['pv_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_pvs, scen_titles)
193
- results['pv_scatter_base'] = plot_scatter_comparison(cluster_pvs.compare(pvs), 'Present Value Calibration - Base Scenario')
 
 
 
 
 
 
 
194
 
195
- # Summary comparison plot
196
- fig, ax = plt.subplots(figsize=(12, 8))
197
- comparison_data = {
198
- 'Cashflow Calibration': [
199
- abs(cluster_cfs.compare_total(cfs)['error'].mean()),
200
- abs(cluster_cfs.compare_total(pvs)['error'].mean())
201
- ],
202
- 'Policy Attribute Calibration': [
203
- abs(cluster_attrs.compare_total(cfs)['error'].mean()),
204
- abs(cluster_attrs.compare_total(pvs)['error'].mean())
205
- ],
206
- 'Present Value Calibration': [
207
- abs(cluster_pvs.compare_total(cfs)['error'].mean()),
208
- abs(cluster_pvs.compare_total(pvs)['error'].mean())
209
- ]
210
- }
211
 
212
- x = np.arange(2)
213
- width = 0.25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
 
215
- ax.bar(x - width, comparison_data['Cashflow Calibration'], width, label='Cashflow Calibration')
216
- ax.bar(x, comparison_data['Policy Attribute Calibration'], width, label='Policy Attribute Calibration')
217
- ax.bar(x + width, comparison_data['Present Value Calibration'], width, label='Present Value Calibration')
218
 
219
- ax.set_ylabel('Mean Absolute Error')
220
- ax.set_title('Calibration Method Comparison')
221
- ax.set_xticks(x)
222
- ax.set_xticklabels(['Cashflows', 'Present Values'])
223
- ax.legend()
224
- ax.grid(True, alpha=0.3)
225
 
226
- buf = io.BytesIO()
227
- plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
228
- buf.seek(0)
229
- plt.close()
230
- results['summary_plot'] = Image.open(buf)
231
 
232
  return results
233
 
 
 
 
 
 
 
234
  except Exception as e:
 
235
  return {"error": f"Error processing files: {str(e)}"}
236
 
 
237
  def create_interface():
238
- with gr.Blocks(title="Cluster Model Points Analysis", theme=gr.themes.Soft()) as demo:
239
  gr.Markdown("""
240
  # Cluster Model Points Analysis
241
 
242
  This application applies cluster analysis to model point selection for insurance portfolios.
243
- Upload your Excel files to analyze cashflows, policy attributes, and present values using different calibration methods.
244
 
245
- **Required Files:**
246
- - 3 Cashflow files (Base, Lapse stress, Mortality stress scenarios)
247
- - 1 Policy data file
248
- - 3 Present value files (Base, Lapse stress, Mortality stress scenarios)
 
 
 
 
249
  """)
250
 
251
  with gr.Row():
252
- with gr.Column():
253
- gr.Markdown("### Upload Files")
254
- cashflow_base = gr.File(label="Cashflows - Base Scenario", file_types=[".xlsx"])
255
- cashflow_lapse = gr.File(label="Cashflows - Lapse Stress (+50%)", file_types=[".xlsx"])
256
- cashflow_mort = gr.File(label="Cashflows - Mortality Stress (+15%)", file_types=[".xlsx"])
257
- policy_data = gr.File(label="Policy Data", file_types=[".xlsx"])
258
- pv_base = gr.File(label="Present Values - Base Scenario", file_types=[".xlsx"])
259
- pv_lapse = gr.File(label="Present Values - Lapse Stress", file_types=[".xlsx"])
260
- pv_mort = gr.File(label="Present Values - Mortality Stress", file_types=[".xlsx"])
261
-
262
- analyze_btn = gr.Button("Analyze", variant="primary", size="lg")
263
-
264
- with gr.Tabs():
265
- with gr.TabItem("Summary"):
266
- summary_plot = gr.Image(label="Calibration Methods Comparison")
267
 
268
- with gr.TabItem("Cashflow Calibration"):
269
- gr.Markdown("### Results using Annual Cashflows as Calibration Variables")
270
 
271
  with gr.Row():
272
- cf_base_table = gr.Dataframe(label="Base Scenario Comparison")
273
- cf_policy_attrs = gr.Dataframe(label="Policy Attributes Comparison")
274
-
275
- cf_cashflow_plot = gr.Image(label="Cashflow Comparisons Across Scenarios")
276
- cf_scatter_base = gr.Image(label="Scatter Plot - Base Scenario")
277
-
278
  with gr.Row():
279
- cf_pv_base = gr.Dataframe(label="Present Values - Base")
280
- cf_pv_lapse = gr.Dataframe(label="Present Values - Lapse Stress")
281
- cf_pv_mort = gr.Dataframe(label="Present Values - Mortality Stress")
282
-
283
- with gr.TabItem("Policy Attribute Calibration"):
284
- gr.Markdown("### Results using Policy Attributes as Calibration Variables")
285
-
286
  with gr.Row():
287
- attr_cf_base = gr.Dataframe(label="Cashflows - Base Scenario")
288
- attr_policy_attrs = gr.Dataframe(label="Policy Attributes Comparison")
289
 
290
- attr_cashflow_plot = gr.Image(label="Cashflow Comparisons Across Scenarios")
291
- attr_scatter_base = gr.Image(label="Scatter Plot - Base Scenario")
292
- attr_pv_base = gr.Dataframe(label="Present Values - Base Scenario")
 
 
293
 
294
- with gr.TabItem("Present Value Calibration"):
295
- gr.Markdown("### Results using Present Values as Calibration Variables")
296
-
297
  with gr.Row():
298
- pv_cf_base = gr.Dataframe(label="Cashflows - Base Scenario")
299
- pv_policy_attrs = gr.Dataframe(label="Policy Attributes Comparison")
300
-
301
- pv_cashflow_plot = gr.Image(label="Cashflow Comparisons Across Scenarios")
302
- pv_scatter_base = gr.Image(label="Scatter Plot - Base Scenario")
303
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
304
  with gr.Row():
305
- pv_pv_base = gr.Dataframe(label="Present Values - Base")
306
- pv_pv_lapse = gr.Dataframe(label="Present Values - Lapse Stress")
307
- pv_pv_mort = gr.Dataframe(label="Present Values - Mortality Stress")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
308
 
309
- def update_interface(cashflow_base, cashflow_lapse, cashflow_mort, policy_data, pv_base, pv_lapse, pv_mort):
310
- if not all([cashflow_base, cashflow_lapse, cashflow_mort, policy_data, pv_base, pv_lapse, pv_mort]):
311
- return [None] * 17
 
 
 
312
 
313
- results = process_files(cashflow_base, cashflow_lapse, cashflow_mort, policy_data, pv_base, pv_lapse, pv_mort)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
314
 
315
  if "error" in results:
316
- gr.Warning(results["error"])
317
- return [None] * 17
318
 
319
  return [
320
  results.get('summary_plot'),
321
- results.get('cf_base_table'),
322
- results.get('cf_policy_attrs'),
323
- results.get('cf_cashflow_plot'),
324
- results.get('cf_scatter_base'),
325
- results.get('cf_pv_base'),
326
- results.get('cf_pv_lapse'),
327
- results.get('cf_pv_mort'),
328
- results.get('attr_cf_base'),
329
- results.get('attr_policy_attrs'),
330
- results.get('attr_cashflow_plot'),
331
- results.get('attr_scatter_base'),
332
- results.get('attr_pv_base'),
333
- results.get('pv_cf_base'),
334
- results.get('pv_policy_attrs'),
335
- results.get('pv_cashflow_plot'),
336
- results.get('pv_scatter_base'),
337
- results.get('pv_pv_base'),
338
- results.get('pv_pv_lapse'),
339
- results.get('pv_pv_mort')
340
  ]
341
-
342
  analyze_btn.click(
343
- update_interface,
344
- inputs=[cashflow_base, cashflow_lapse, cashflow_mort, policy_data, pv_base, pv_lapse, pv_mort],
345
- outputs=[
346
- summary_plot,
347
- cf_base_table, cf_policy_attrs, cf_cashflow_plot, cf_scatter_base,
348
- cf_pv_base, cf_pv_lapse, cf_pv_mort,
349
- attr_cf_base, attr_policy_attrs, attr_cashflow_plot, attr_scatter_base, attr_pv_base,
350
- pv_cf_base, pv_policy_attrs, pv_cashflow_plot, pv_scatter_base,
351
- pv_pv_base, pv_pv_lapse, pv_pv_mort
 
 
 
 
 
 
 
 
 
 
352
  ]
 
 
 
 
 
 
353
  )
354
-
355
  return demo
356
 
357
  if __name__ == "__main__":
358
- demo = create_interface()
359
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  import matplotlib.pyplot as plt
7
  import matplotlib.cm
8
  import io
9
+ import os # Added for path joining
10
  from PIL import Image
11
 
12
+ # Define the paths for example data
13
+ EXAMPLE_DATA_DIR = "eg_data"
14
+ EXAMPLE_FILES = {
15
+ "cashflow_base": os.path.join(EXAMPLE_DATA_DIR, "cashflows_seriatim_10K.xlsx"),
16
+ "cashflow_lapse": os.path.join(EXAMPLE_DATA_DIR, "cashflows_seriatim_10K_lapse50.xlsx"),
17
+ "cashflow_mort": os.path.join(EXAMPLE_DATA_DIR, "cashflows_seriatim_10K_mort15.xlsx"),
18
+ "policy_data": os.path.join(EXAMPLE_DATA_DIR, "model_point_table.xlsx"), # Assuming this is the correct path/name for the example
19
+ "pv_base": os.path.join(EXAMPLE_DATA_DIR, "pv_seriatim_10K.xlsx"),
20
+ "pv_lapse": os.path.join(EXAMPLE_DATA_DIR, "pv_seriatim_10K_lapse50.xlsx"),
21
+ "pv_mort": os.path.join(EXAMPLE_DATA_DIR, "pv_seriatim_10K_mort15.xlsx"),
22
+ }
23
+
24
  class Clusters:
25
  def __init__(self, loc_vars):
26
  self.kmeans = kmeans = KMeans(n_clusters=1000, random_state=0, n_init=10).fit(np.ascontiguousarray(loc_vars))
27
  closest, _ = pairwise_distances_argmin_min(kmeans.cluster_centers_, np.ascontiguousarray(loc_vars))
28
 
29
+ rep_ids = pd.Series(data=(closest+1)) # 0-based to 1-based indexes
30
  rep_ids.name = 'policy_id'
31
  rep_ids.index.name = 'cluster_id'
32
  self.rep_ids = rep_ids
 
38
  temp = df.copy()
39
  temp['cluster_id'] = self.kmeans.labels_
40
  temp = temp.set_index('cluster_id')
41
+ agg = {c: (agg[c] if agg and c in agg else 'sum') for c in temp.columns} if agg else "sum"
42
  return temp.groupby(temp.index).agg(agg)
43
 
44
  def extract_reps(self, df):
 
52
  if agg:
53
  cols = df.columns
54
  mult = pd.DataFrame({c: (self.policy_count if (c not in agg or agg[c] == 'sum') else 1) for c in cols})
55
+ # Ensure mult has same index as extract_reps(df) for proper alignment
56
+ extracted_df = self.extract_reps(df)
57
+ mult.index = extracted_df.index
58
+ return extracted_df.mul(mult)
59
  else:
60
  return self.extract_reps(df).mul(self.policy_count, axis=0)
61
 
 
68
  def compare_total(self, df, agg=None):
69
  """Aggregate df by columns"""
70
  if agg:
71
+ # cols = df.columns # Not used
72
  op = {c: (agg[c] if c in agg else 'sum') for c in df.columns}
73
  actual = df.agg(op)
 
74
 
75
+ # For estimate, ensure aggregation ops are correctly applied *after* scaling
76
+ scaled_reps = self.extract_and_scale_reps(df, agg=op) # Pass op to ensure correct scaling for mean
77
+
78
+ # Corrected aggregation for estimate when 'mean' is involved
79
+ estimate_agg_ops = {}
80
+ for col_name, agg_type in op.items():
81
+ if agg_type == 'mean':
82
+ # Weighted average for mean columns
83
+ estimate_agg_ops[col_name] = lambda s, c=col_name: (s * self.policy_count.reindex(s.index)).sum() / self.policy_count.reindex(s.index).sum() if c in self.policy_count.name else s.mean()
84
+ else: # 'sum'
85
+ estimate_agg_ops[col_name] = 'sum'
86
+
87
+ # Need to handle the case where extract_and_scale_reps already applied scaling for sum
88
+ # The logic in extract_and_scale_reps is:
89
+ # mult = pd.DataFrame({c: (self.policy_count if (c not in agg or agg[c] == 'sum') else 1) for c in cols})
90
+ # This means 'mean' columns are NOT multiplied by policy_count initially.
91
+
92
+ # Let's re-think the estimate aggregation for 'mean'
93
+ estimate_scaled = self.extract_and_scale_reps(df, agg=op) # agg=op is important here
94
+
95
+ final_estimate_ops = {}
96
+ for col, method in op.items():
97
+ if method == 'mean':
98
+ # For mean, we need the sum of (value * policy_count) / sum(policy_count)
99
+ # extract_and_scale_reps with agg=op should have scaled sum-columns by policy_count
100
+ # and mean-columns by 1. So, for mean columns in estimate_scaled, we need to multiply by policy_count,
101
+ # sum them up, and divide by total policy_count.
102
+ # However, the current extract_and_scale_reps scales 'mean' columns by 1.
103
+ # So we need to take the mean of these scaled (by 1) values, but it should be a weighted mean.
104
+
105
+ # Let's try to be more direct:
106
+ # Get the representative policies (unscaled for mean columns)
107
+ reps_unscaled_for_mean = self.extract_reps(df)
108
+ estimate_values = {}
109
+ for c in df.columns:
110
+ if op[c] == 'sum':
111
+ estimate_values[c] = reps_unscaled_for_mean[c].mul(self.policy_count, axis=0).sum()
112
+ elif op[c] == 'mean':
113
+ weighted_sum = (reps_unscaled_for_mean[c] * self.policy_count).sum()
114
+ total_weight = self.policy_count.sum()
115
+ estimate_values[c] = weighted_sum / total_weight if total_weight else 0
116
+ estimate = pd.Series(estimate_values)
117
+
118
+ else: # original 'sum' logic for all columns
119
+ final_estimate_ops[col] = 'sum' # All columns in estimate_scaled are ready to be summed up
120
+ estimate = estimate_scaled.agg(final_estimate_ops)
121
+
122
+
123
+ else: # Original logic if no agg is specified (all sum)
124
  actual = df.sum()
125
  estimate = self.extract_and_scale_reps(df).sum()
126
 
127
  return pd.DataFrame({'actual': actual, 'estimate': estimate, 'error': estimate / actual - 1})
128
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
  def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
131
  """Create cashflow comparison plots"""
132
+ if not cfs_list or not cluster_obj or not titles:
133
+ return None # Or a placeholder image
134
+ num_plots = len(cfs_list)
135
+ if num_plots == 0:
136
+ return None
137
+
138
+ # Determine subplot layout (e.g., 2x2 or adapt)
139
+ cols = 2
140
+ rows = (num_plots + cols - 1) // cols
141
+
142
+ fig, axes = plt.subplots(rows, cols, figsize=(15, 5 * rows), squeeze=False) # Ensure axes is always 2D
143
  axes = axes.flatten()
144
 
145
  for i, (df, title) in enumerate(zip(cfs_list, titles)):
146
  if i < len(axes):
147
  comparison = cluster_obj.compare_total(df)
148
  comparison[['actual', 'estimate']].plot(ax=axes[i], grid=True, title=title)
149
+ axes[i].set_xlabel('Time') # Assuming x-axis is time for cashflows
150
+ axes[i].set_ylabel('Value')
151
 
152
+ # Hide any unused subplots
153
+ for j in range(i + 1, len(axes)):
154
+ fig.delaxes(axes[j])
155
+
156
  plt.tight_layout()
157
  buf = io.BytesIO()
158
+ plt.savefig(buf, format='png', dpi=100) # Lowered DPI slightly for potentially faster rendering
159
  buf.seek(0)
160
+ img = Image.open(buf)
161
+ plt.close(fig) # Ensure figure is closed
162
+ return img
163
 
164
+ def plot_scatter_comparison(df_compare_output, title):
165
+ """Create scatter plot comparison from compare() output"""
166
+ if df_compare_output is None or df_compare_output.empty:
167
+ # Create a blank plot with a message
168
+ fig, ax = plt.subplots(figsize=(12, 8))
169
+ ax.text(0.5, 0.5, "No data to display", ha='center', va='center', fontsize=15)
170
+ ax.set_title(title)
171
+ buf = io.BytesIO()
172
+ plt.savefig(buf, format='png', dpi=100)
173
+ buf.seek(0)
174
+ img = Image.open(buf)
175
+ plt.close(fig)
176
+ return img
177
+
178
+ fig, ax = plt.subplots(figsize=(12, 8)) # Use a single Axes object
179
 
180
+ if not isinstance(df_compare_output.index, pd.MultiIndex) or df_compare_output.index.nlevels < 2:
181
+ gr.Warning("Scatter plot data is not in the expected multi-index format. Plotting raw actual vs estimate.")
182
+ ax.scatter(df_compare_output['actual'], df_compare_output['estimate'], s=9, alpha=0.6)
183
+ else:
184
+ unique_levels = df_compare_output.index.get_level_values(1).unique()
185
+ colors = matplotlib.cm.rainbow(np.linspace(0, 1, len(unique_levels)))
186
+
187
+ for item_level, color_val in zip(unique_levels, colors):
188
+ subset = df_compare_output.xs(item_level, level=1)
189
+ ax.scatter(subset['actual'], subset['estimate'], color=color_val, s=9, alpha=0.6, label=item_level)
190
+ if len(unique_levels) > 1 and len(unique_levels) <=10: # Add legend if not too many items
191
+ ax.legend(title=df_compare_output.index.names[1])
192
+
193
+
194
+ ax.set_xlabel('Actual')
195
+ ax.set_ylabel('Estimate')
196
+ ax.set_title(title)
197
+ ax.grid(True)
198
 
199
  # Draw identity line
200
  lims = [
201
+ np.min([ax.get_xlim(), ax.get_ylim()]),
202
+ np.max([ax.get_xlim(), ax.get_ylim()]),
203
  ]
204
+ if lims[0] != lims[1]: # Avoid issues if all data is zero or a single point
205
+ ax.plot(lims, lims, 'r-', linewidth=0.5)
206
+ ax.set_xlim(lims)
207
+ ax.set_ylim(lims)
208
 
209
  buf = io.BytesIO()
210
+ plt.savefig(buf, format='png', dpi=100)
211
  buf.seek(0)
212
+ img = Image.open(buf)
213
+ plt.close(fig)
214
+ return img
215
+
216
 
217
+ def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
218
+ policy_data_path, pv_base_path, pv_lapse_path, pv_mort_path):
219
+ """Main processing function - now accepts file paths"""
220
  try:
221
+ # Read uploaded files using paths
222
+ cfs = pd.read_excel(cashflow_base_path, index_col=0)
223
+ cfs_lapse50 = pd.read_excel(cashflow_lapse_path, index_col=0)
224
+ cfs_mort15 = pd.read_excel(cashflow_mort_path, index_col=0)
225
 
226
+ pol_data_full = pd.read_excel(policy_data_path, index_col=0)
227
+ # Ensure the correct columns are selected for pol_data
228
+ required_cols = ['age_at_entry', 'policy_term', 'sum_assured', 'duration_mth']
229
+ if all(col in pol_data_full.columns for col in required_cols):
230
+ pol_data = pol_data_full[required_cols]
231
+ else:
232
+ # Fallback or error if columns are missing. For now, try to use as is or a subset.
233
+ gr.Warning(f"Policy data might be missing required columns. Found: {pol_data_full.columns.tolist()}")
234
+ pol_data = pol_data_full
235
+
236
+
237
+ pvs = pd.read_excel(pv_base_path, index_col=0)
238
+ pvs_lapse50 = pd.read_excel(pv_lapse_path, index_col=0)
239
+ pvs_mort15 = pd.read_excel(pv_mort_path, index_col=0)
240
 
241
  cfs_list = [cfs, cfs_lapse50, cfs_mort15]
242
+ # pvs_list = [pvs, pvs_lapse50, pvs_mort15] # Not directly used for plotting in this structure
243
  scen_titles = ['Base', 'Lapse+50%', 'Mort+15%']
244
 
245
  results = {}
246
 
247
+ mean_attrs = {'age_at_entry':'mean', 'policy_term':'mean', 'duration_mth':'mean', 'sum_assured': 'sum'} # sum_assured is usually summed
248
+
249
+ # --- 1. Cashflow Calibration ---
250
  cluster_cfs = Clusters(cfs)
251
 
252
+ results['cf_total_base_table'] = cluster_cfs.compare_total(cfs)
253
+ # results['cf_total_lapse_table'] = cluster_cfs.compare_total(cfs_lapse50) # For full detail if needed
254
+ # results['cf_total_mort_table'] = cluster_cfs.compare_total(cfs_mort15)
255
+
256
+ results['cf_policy_attrs_total'] = cluster_cfs.compare_total(pol_data, agg=mean_attrs)
 
 
 
257
 
258
+ results['cf_pv_total_base'] = cluster_cfs.compare_total(pvs)
259
+ results['cf_pv_total_lapse'] = cluster_cfs.compare_total(pvs_lapse50)
260
+ results['cf_pv_total_mort'] = cluster_cfs.compare_total(pvs_mort15)
 
261
 
 
262
  results['cf_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_cfs, scen_titles)
263
+ results['cf_scatter_cashflows_base'] = plot_scatter_comparison(cluster_cfs.compare(cfs), 'Cashflow Calib. - Cashflows (Base)')
264
+ # results['cf_scatter_policy_attrs'] = plot_scatter_comparison(cluster_cfs.compare(pol_data, agg=mean_attrs), 'Cashflow Calib. - Policy Attributes')
265
+ # results['cf_scatter_pvs_base'] = plot_scatter_comparison(cluster_cfs.compare(pvs), 'Cashflow Calib. - PVs (Base)')
266
+
267
+ # --- 2. Policy Attribute Calibration ---
268
+ # Standardize policy attributes
269
+ if not pol_data.empty and (pol_data.max() - pol_data.min()).all() != 0 : # Avoid division by zero if a column is constant
270
+ loc_vars_attrs = (pol_data - pol_data.min()) / (pol_data.max() - pol_data.min())
271
+ else:
272
+ gr.Warning("Policy data for attribute calibration is empty or has no variance. Skipping attribute calibration plots.")
273
+ loc_vars_attrs = pol_data # or handle as an error/skip
274
 
275
+ if not loc_vars_attrs.empty:
276
+ cluster_attrs = Clusters(loc_vars_attrs)
277
+ results['attr_total_cf_base'] = cluster_attrs.compare_total(cfs)
278
+ results['attr_policy_attrs_total'] = cluster_attrs.compare_total(pol_data, agg=mean_attrs)
279
+ results['attr_total_pv_base'] = cluster_attrs.compare_total(pvs)
280
+ results['attr_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_attrs, scen_titles)
281
+ results['attr_scatter_cashflows_base'] = plot_scatter_comparison(cluster_attrs.compare(cfs), 'Policy Attr. Calib. - Cashflows (Base)')
282
+ # results['attr_scatter_policy_attrs'] = plot_scatter_comparison(cluster_attrs.compare(pol_data, agg=mean_attrs), 'Policy Attr. Calib. - Policy Attributes')
283
+
284
+ else: # Fill with None if skipped
285
+ results['attr_total_cf_base'] = pd.DataFrame()
286
+ results['attr_policy_attrs_total'] = pd.DataFrame()
287
+ results['attr_total_pv_base'] = pd.DataFrame()
288
+ results['attr_cashflow_plot'] = None
289
+ results['attr_scatter_cashflows_base'] = None
290
+
291
+
292
+ # --- 3. Present Value Calibration ---
293
+ cluster_pvs = Clusters(pvs)
294
 
295
+ results['pv_total_cf_base'] = cluster_pvs.compare_total(cfs)
296
+ results['pv_policy_attrs_total'] = cluster_pvs.compare_total(pol_data, agg=mean_attrs)
 
 
 
297
 
298
+ results['pv_total_pv_base'] = cluster_pvs.compare_total(pvs)
299
+ results['pv_total_pv_lapse'] = cluster_pvs.compare_total(pvs_lapse50)
300
+ results['pv_total_pv_mort'] = cluster_pvs.compare_total(pvs_mort15)
301
 
 
 
 
 
 
302
  results['pv_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_pvs, scen_titles)
303
+ results['pv_scatter_pvs_base'] = plot_scatter_comparison(cluster_pvs.compare(pvs), 'PV Calib. - PVs (Base)')
304
+ # results['pv_scatter_cashflows_base'] = plot_scatter_comparison(cluster_pvs.compare(cfs), 'PV Calib. - Cashflows (Base)')
305
+
306
+
307
+ # --- Summary Comparison Plot Data ---
308
+ # Error metric: Mean Absolute Percentage Error for the 'TOTAL' net present value of cashflows (usually the 'PV_NetCF' column)
309
+ # Or sum of absolute errors if percentage is problematic (e.g. actual is zero)
310
+ # For simplicity, using mean of the 'error' column from compare_total for key metrics
311
 
312
+ error_data = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
313
 
314
+ # Cashflow Calibration Errors
315
+ if 'PV_NetCF' in pvs.columns:
316
+ err_cf_cal_pv_base = cluster_cfs.compare_total(pvs).loc['PV_NetCF', 'error']
317
+ err_cf_cal_pv_lapse = cluster_cfs.compare_total(pvs_lapse50).loc['PV_NetCF', 'error']
318
+ err_cf_cal_pv_mort = cluster_cfs.compare_total(pvs_mort15).loc['PV_NetCF', 'error']
319
+ error_data['CF Calib. (PV NetCF)'] = [
320
+ abs(err_cf_cal_pv_base), abs(err_cf_cal_pv_lapse), abs(err_cf_cal_pv_mort)
321
+ ]
322
+ else: # Fallback if PV_NetCF is not present
323
+ error_data['CF Calib. (PV NetCF)'] = [
324
+ abs(cluster_cfs.compare_total(pvs)['error'].mean()),
325
+ abs(cluster_cfs.compare_total(pvs_lapse50)['error'].mean()),
326
+ abs(cluster_cfs.compare_total(pvs_mort15)['error'].mean())
327
+ ]
328
+
329
+
330
+ # Policy Attribute Calibration Errors
331
+ if not loc_vars_attrs.empty and 'PV_NetCF' in pvs.columns:
332
+ err_attr_cal_pv_base = cluster_attrs.compare_total(pvs).loc['PV_NetCF', 'error']
333
+ err_attr_cal_pv_lapse = cluster_attrs.compare_total(pvs_lapse50).loc['PV_NetCF', 'error']
334
+ err_attr_cal_pv_mort = cluster_attrs.compare_total(pvs_mort15).loc['PV_NetCF', 'error']
335
+ error_data['Attr Calib. (PV NetCF)'] = [
336
+ abs(err_attr_cal_pv_base), abs(err_attr_cal_pv_lapse), abs(err_attr_cal_pv_mort)
337
+ ]
338
+ else:
339
+ error_data['Attr Calib. (PV NetCF)'] = [np.nan, np.nan, np.nan] # Placeholder if skipped
340
+
341
+
342
+ # Present Value Calibration Errors
343
+ if 'PV_NetCF' in pvs.columns:
344
+ err_pv_cal_pv_base = cluster_pvs.compare_total(pvs).loc['PV_NetCF', 'error']
345
+ err_pv_cal_pv_lapse = cluster_pvs.compare_total(pvs_lapse50).loc['PV_NetCF', 'error']
346
+ err_pv_cal_pv_mort = cluster_pvs.compare_total(pvs_mort15).loc['PV_NetCF', 'error']
347
+ error_data['PV Calib. (PV NetCF)'] = [
348
+ abs(err_pv_cal_pv_base), abs(err_pv_cal_pv_lapse), abs(err_pv_cal_pv_mort)
349
+ ]
350
+ else:
351
+ error_data['PV Calib. (PV NetCF)'] = [
352
+ abs(cluster_pvs.compare_total(pvs)['error'].mean()),
353
+ abs(cluster_pvs.compare_total(pvs_lapse50)['error'].mean()),
354
+ abs(cluster_pvs.compare_total(pvs_mort15)['error'].mean())
355
+ ]
356
 
357
+ # Create Summary Plot
358
+ summary_df = pd.DataFrame(error_data, index=['Base', 'Lapse+50%', 'Mort+15%'])
 
359
 
360
+ fig_summary, ax_summary = plt.subplots(figsize=(10, 6))
361
+ summary_df.plot(kind='bar', ax=ax_summary, grid=True)
362
+ ax_summary.set_ylabel('Mean Absolute Error (of PV_NetCF)')
363
+ ax_summary.set_title('Calibration Method Comparison - Error in Total PV Net Cashflow')
364
+ ax_summary.tick_params(axis='x', rotation=0)
365
+ plt.tight_layout()
366
 
367
+ buf_summary = io.BytesIO()
368
+ plt.savefig(buf_summary, format='png', dpi=100)
369
+ buf_summary.seek(0)
370
+ results['summary_plot'] = Image.open(buf_summary)
371
+ plt.close(fig_summary)
372
 
373
  return results
374
 
375
+ except FileNotFoundError as e:
376
+ gr.Error(f"File not found: {e.filename}. Please ensure example files are in '{EXAMPLE_DATA_DIR}' or all files are uploaded.")
377
+ return {"error": f"File not found: {e.filename}"}
378
+ except KeyError as e:
379
+ gr.Error(f"A required column is missing from one of the excel files: {e}. Please check data format.")
380
+ return {"error": f"Missing column: {e}"}
381
  except Exception as e:
382
+ gr.Error(f"Error processing files: {str(e)}")
383
  return {"error": f"Error processing files: {str(e)}"}
384
 
385
+
386
def create_interface():
    """Build and return the Gradio Blocks UI for the cluster analysis app.

    Layout: seven Excel file inputs (cashflows, policy data and present values
    for the base / lapse / mortality scenarios), a button that pre-fills the
    inputs with example-data paths, and four result tabs (summary plus one tab
    per calibration method).  All computation is delegated to ``process_files``.

    Returns:
        gr.Blocks: the assembled (not yet launched) Gradio app.
    """
    with gr.Blocks(title="Cluster Model Points Analysis") as demo:  # Removed theme
        gr.Markdown("""
        # Cluster Model Points Analysis

        This application applies cluster analysis to model point selection for insurance portfolios.
        Upload your Excel files or use the example data to analyze cashflows, policy attributes, and present values using different calibration methods.

        **Required Files (Excel .xlsx):**
        - Cashflows - Base Scenario
        - Cashflows - Lapse Stress (+50%)
        - Cashflows - Mortality Stress (+15%)
        - Policy Data (including 'age_at_entry', 'policy_term', 'sum_assured', 'duration_mth')
        - Present Values - Base Scenario
        - Present Values - Lapse Stress
        - Present Values - Mortality Stress
        """)

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Upload Files or Load Examples")
                load_example_btn = gr.Button("Load Example Data")

        with gr.Row():
            cashflow_base_input = gr.File(label="Cashflows - Base", file_types=[".xlsx"])
            cashflow_lapse_input = gr.File(label="Cashflows - Lapse Stress", file_types=[".xlsx"])
            cashflow_mort_input = gr.File(label="Cashflows - Mortality Stress", file_types=[".xlsx"])

        with gr.Row():
            policy_data_input = gr.File(label="Policy Data", file_types=[".xlsx"])
            pv_base_input = gr.File(label="Present Values - Base", file_types=[".xlsx"])
            pv_lapse_input = gr.File(label="Present Values - Lapse Stress", file_types=[".xlsx"])

        with gr.Row():
            pv_mort_input = gr.File(label="Present Values - Mortality Stress", file_types=[".xlsx"])

        analyze_btn = gr.Button("Analyze Dataset", variant="primary", size="lg")

        with gr.Tabs():
            with gr.TabItem("📊 Summary"):
                summary_plot_output = gr.Image(label="Calibration Methods Comparison (Error in Total PV Net Cashflow)")

            with gr.TabItem("💸 Cashflow Calibration"):
                gr.Markdown("### Results: Using Annual Cashflows as Calibration Variables")
                with gr.Row():
                    cf_total_base_table_out = gr.Dataframe(label="Overall Comparison - Base Scenario (Cashflows)")
                    cf_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes")
                cf_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate) Across Scenarios")
                cf_scatter_cashflows_base_out = gr.Image(label="Scatter Plot - Per-Cluster Cashflows (Base Scenario)")
                with gr.Accordion("Present Value Comparisons (Total)", open=False):
                    with gr.Row():
                        cf_pv_total_base_out = gr.Dataframe(label="PVs - Base Total")
                        cf_pv_total_lapse_out = gr.Dataframe(label="PVs - Lapse Stress Total")
                        cf_pv_total_mort_out = gr.Dataframe(label="PVs - Mortality Stress Total")

            with gr.TabItem("👤 Policy Attribute Calibration"):
                gr.Markdown("### Results: Using Policy Attributes as Calibration Variables")
                with gr.Row():
                    attr_total_cf_base_out = gr.Dataframe(label="Overall Comparison - Base Scenario (Cashflows)")
                    attr_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes")
                attr_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate) Across Scenarios")
                attr_scatter_cashflows_base_out = gr.Image(label="Scatter Plot - Per-Cluster Cashflows (Base Scenario)")
                with gr.Accordion("Present Value Comparisons (Total)", open=False):
                    attr_total_pv_base_out = gr.Dataframe(label="PVs - Base Scenario Total")

            with gr.TabItem("💰 Present Value Calibration"):
                gr.Markdown("### Results: Using Present Values (Base Scenario) as Calibration Variables")
                with gr.Row():
                    pv_total_cf_base_out = gr.Dataframe(label="Overall Comparison - Base Scenario (Cashflows)")
                    pv_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes")
                pv_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate) Across Scenarios")
                pv_scatter_pvs_base_out = gr.Image(label="Scatter Plot - Per-Cluster Present Values (Base Scenario)")
                with gr.Accordion("Present Value Comparisons (Total)", open=False):
                    with gr.Row():
                        pv_total_pv_base_out = gr.Dataframe(label="PVs - Base Total")
                        pv_total_pv_lapse_out = gr.Dataframe(label="PVs - Lapse Stress Total")
                        pv_total_pv_mort_out = gr.Dataframe(label="PVs - Mortality Stress Total")

        # Single source of truth for the seven file inputs, in the exact
        # positional order that process_files expects its arguments.
        file_inputs = [cashflow_base_input, cashflow_lapse_input, cashflow_mort_input,
                       policy_data_input, pv_base_input, pv_lapse_input, pv_mort_input]

        # --- Helper function to prepare outputs ---
        def get_all_output_components():
            """Return every output component, in the order handle_analysis emits values."""
            return [
                summary_plot_output,
                # Cashflow Calib Outputs
                cf_total_base_table_out, cf_policy_attrs_total_out,
                cf_cashflow_plot_out, cf_scatter_cashflows_base_out,
                cf_pv_total_base_out, cf_pv_total_lapse_out, cf_pv_total_mort_out,
                # Attribute Calib Outputs
                attr_total_cf_base_out, attr_policy_attrs_total_out,
                attr_cashflow_plot_out, attr_scatter_cashflows_base_out, attr_total_pv_base_out,
                # PV Calib Outputs
                pv_total_cf_base_out, pv_policy_attrs_total_out,
                pv_cashflow_plot_out, pv_scatter_pvs_base_out,
                pv_total_pv_base_out, pv_total_pv_lapse_out, pv_total_pv_mort_out
            ]

        # --- Action for Analyze Button ---
        def handle_analysis(f1, f2, f3, f4, f5, f6, f7):
            """Validate the seven file inputs, run the analysis, map results to outputs.

            Each input is either a Gradio file object from a direct upload (the
            temp path lives in its ``.name`` attribute) or a plain string path
            set by the example loader.

            Raises:
                gr.Error: if any input is missing or of an unexpected type.
            """
            files = [f1, f2, f3, f4, f5, f6, f7]

            file_paths = []
            for i, f_obj in enumerate(files):
                if f_obj is None:
                    # BUG FIX: gr.Error is an exception class — it must be
                    # *raised* to surface in the UI; instantiating it and
                    # returning was a silent no-op for the user.
                    raise gr.Error(f"Missing file input for argument {i+1}. Please upload all files or load examples.")

                # Gradio FileData object (from direct upload) -> use its temp path.
                if hasattr(f_obj, 'name') and isinstance(f_obj.name, str):
                    file_paths.append(f_obj.name)
                # Already a string path (from example load).
                elif isinstance(f_obj, str):
                    file_paths.append(f_obj)
                else:
                    raise gr.Error(f"Invalid file input for argument {i+1}. Type: {type(f_obj)}")

            results = process_files(*file_paths)

            if "error" in results:
                # process_files already surfaced the error; clear all outputs.
                return [None] * len(get_all_output_components())

            return [
                results.get('summary_plot'),
                # CF Calib
                results.get('cf_total_base_table'), results.get('cf_policy_attrs_total'),
                results.get('cf_cashflow_plot'), results.get('cf_scatter_cashflows_base'),
                results.get('cf_pv_total_base'), results.get('cf_pv_total_lapse'), results.get('cf_pv_total_mort'),
                # Attr Calib
                results.get('attr_total_cf_base'), results.get('attr_policy_attrs_total'),
                results.get('attr_cashflow_plot'), results.get('attr_scatter_cashflows_base'), results.get('attr_total_pv_base'),
                # PV Calib
                results.get('pv_total_cf_base'), results.get('pv_policy_attrs_total'),
                results.get('pv_cashflow_plot'), results.get('pv_scatter_pvs_base'),
                results.get('pv_total_pv_base'), results.get('pv_total_pv_lapse'), results.get('pv_total_pv_mort')
            ]

        analyze_btn.click(
            handle_analysis,
            inputs=file_inputs,
            outputs=get_all_output_components()
        )

        # --- Action for Load Example Data Button ---
        def load_example_files():
            """Fill the file inputs with example-data paths, validating they exist.

            Raises:
                gr.Error: if any expected example file is absent on disk.
            """
            missing_files = [fp for fp in EXAMPLE_FILES.values() if not os.path.exists(fp)]
            if missing_files:
                # BUG FIX: raise (not just construct) so the error modal shows.
                raise gr.Error(f"Missing example data files in '{EXAMPLE_DATA_DIR}': {', '.join(missing_files)}. Please ensure they exist.")

            gr.Info("Example data paths loaded. Click 'Analyze Dataset'.")
            return [
                EXAMPLE_FILES["cashflow_base"], EXAMPLE_FILES["cashflow_lapse"], EXAMPLE_FILES["cashflow_mort"],
                EXAMPLE_FILES["policy_data"], EXAMPLE_FILES["pv_base"], EXAMPLE_FILES["pv_lapse"],
                EXAMPLE_FILES["pv_mort"]
            ]

        load_example_btn.click(
            load_example_files,
            inputs=[],
            outputs=file_inputs
        )

    return demo
556
 
557
if __name__ == "__main__":
    # Make sure the example-data directory exists so users know where to drop
    # the example Excel files.  exist_ok=True avoids a race between the
    # existence check and creation (e.g. two processes starting at once).
    if not os.path.exists(EXAMPLE_DATA_DIR):
        os.makedirs(EXAMPLE_DATA_DIR, exist_ok=True)
        print(f"Created directory '{EXAMPLE_DATA_DIR}'. Please place example Excel files there.")
        # NOTE(review): empty placeholder files would still fail in
        # pd.read_excel — the user must supply the real workbooks listed below.
        print(f"Expected files in '{EXAMPLE_DATA_DIR}': {list(EXAMPLE_FILES.values())}")

    demo_app = create_interface()
    demo_app.launch()