alidenewade committed on
Commit dd97346 · verified · 1 Parent(s): 8e2e740

Update app.py

Files changed (1)
  1. app.py +410 -347

app.py CHANGED
@@ -2,11 +2,11 @@ import gradio as gr
  import numpy as np
  import pandas as pd
  from sklearn.cluster import KMeans
- from sklearn.metrics import pairwise_distances_argmin_min
  import matplotlib.pyplot as plt
  import matplotlib.cm
  import io
- import os
  from PIL import Image

  # Define the paths for example data
@@ -23,258 +23,267 @@ EXAMPLE_FILES = {

  class Clusters:
      def __init__(self, loc_vars):
          if loc_vars.empty:
              raise ValueError("Input data for KMeans (loc_vars) is empty.")
          if loc_vars.isnull().all().all():
              raise ValueError("Input data for KMeans (loc_vars) contains all NaN values.")

-         # Ensure n_clusters does not exceed the number of samples
-         n_samples = len(loc_vars)
-         n_clusters_to_use = min(1000, n_samples)
-         if n_clusters_to_use == 0:  # Should be caught by loc_vars.empty already
-             raise ValueError("Cannot determine n_clusters as no samples are available.")
-
-         self.kmeans = KMeans(n_clusters=n_clusters_to_use, random_state=0, n_init=10).fit(np.ascontiguousarray(loc_vars))
          closest, _ = pairwise_distances_argmin_min(self.kmeans.cluster_centers_, np.ascontiguousarray(loc_vars))

-         rep_ids = pd.Series(data=(closest + 1))
          rep_ids.name = 'policy_id'
          rep_ids.index.name = 'cluster_id'
          self.rep_ids = rep_ids

-         # Handle case where loc_vars might be shorter than kmeans.labels_ if n_samples was 0 initially (though guarded)
-         if n_samples > 0:
-             self.policy_count = self.agg_by_cluster(pd.DataFrame({'policy_count': [1] * n_samples}))['policy_count']
-         else:  # Should not be reached due to earlier checks
-             self.policy_count = pd.Series(dtype=int).rename_axis('cluster_id')

      def agg_by_cluster(self, df, agg=None):
          temp = df.copy()
-         if len(self.kmeans.labels_) != len(df):
-             # This can happen if df is empty or mismatched with loc_vars length during __init__
-             # Or if called with a df of different length than used for fitting KMeans
-             gr.Warning(f"Length mismatch in agg_by_cluster: kmeans.labels_ ({len(self.kmeans.labels_)}) vs df ({len(df)}). Results may be incorrect.")
-             # Fallback: return an empty df with expected structure or raise error
-             if 'cluster_id' not in df.columns:  # if df doesn't have cluster_id, we can't group
-                 return df.groupby(None).agg(agg if isinstance(agg, dict) else 'sum')  # will likely be empty or error
-
-         temp['cluster_id'] = self.kmeans.labels_[:len(df)]  # Ensure labels don't exceed df length
          temp = temp.set_index('cluster_id')

-         agg_ops = {}
-         if isinstance(agg, dict):
              agg_ops = {c: (agg[c] if c in agg else 'sum') for c in temp.columns}
-         else:  # agg is None or not a dict (e.g. "sum")
-             for col in temp.columns:
-                 if pd.api.types.is_numeric_dtype(temp[col]):
-                     agg_ops[col] = 'sum'  # Default to sum for numeric
-             if not agg_ops and isinstance(agg, str):  # e.g. agg = "sum"
-                 return temp.groupby(temp.index).agg(agg)

          return temp.groupby(temp.index).agg(agg_ops)

      def extract_reps(self, df):
-         # Ensure df has 'policy_id' if it's going to be reset and merged on.
-         # The input df to this method is typically the original data (cfs, pol_data, pvs) which has policy_id as index.
-         # df.reset_index() will move 'policy_id' (or current index name) to a column.
-         # Let's ensure the column name is consistently 'policy_id' after reset_index.
-         df_reset = df.reset_index()
-         original_index_name = df.index.name if df.index.name else 'index'  # Default if no name
-         if 'policy_id' not in df_reset.columns and original_index_name in df_reset.columns:
-             df_reset = df_reset.rename(columns={original_index_name: 'policy_id'})
-         elif 'policy_id' not in df_reset.columns:  # Still no policy_id
-             gr.Error("Could not find 'policy_id' column for merging in extract_reps.")
-             # Return an empty DataFrame with expected structure or raise error
-             # For now, let it proceed; merge might fail or produce unexpected results.
-             # This indicates an issue with input data structure.
-
-         temp = pd.merge(self.rep_ids, df_reset, how='left', on='policy_id')
-         temp.index.name = 'cluster_id'  # The index of rep_ids becomes the new index
-         if 'policy_id' in temp.columns:
-             return temp.drop('policy_id', axis=1)
-         return temp

      def extract_and_scale_reps(self, df, agg=None):
          extracted_df = self.extract_reps(df)
          if extracted_df.empty:
-             return extracted_df
-
-         scaled_df = extracted_df.copy()
-         # Ensure policy_count index is aligned with scaled_df (which is cluster_id)
-         policy_count_aligned = self.policy_count.reindex(scaled_df.index).fillna(0)

          if agg and isinstance(agg, dict):
              for c in extracted_df.columns:
-                 if pd.api.types.is_numeric_dtype(extracted_df[c]):  # Only scale numeric columns
-                     if agg.get(c, 'sum') == 'sum':
-                         scaled_df[c] = extracted_df[c].mul(policy_count_aligned, axis=0)
-         else:  # Default: scale all numeric columns by policy_count
-             for c in extracted_df.columns:
-                 if pd.api.types.is_numeric_dtype(extracted_df[c]):
-                     scaled_df[c] = extracted_df[c].mul(policy_count_aligned, axis=0)
-         return scaled_df

      def compare(self, df, agg=None):
          source = self.agg_by_cluster(df, agg)

-         # For target, we need representative values, scaled appropriately for 'sum' or raw for 'mean' per cluster
-         target_reps = self.extract_reps(df)  # These are the raw representative values per cluster
-
-         # If agg defines means, those are the target estimates per cluster.
-         # If agg defines sums, target estimates are rep_value * policy_count.
-         target_estimates_per_cluster = target_reps.copy()
-         policy_count_aligned = self.policy_count.reindex(target_reps.index).fillna(0)
-
-         if isinstance(agg, dict):
              for col, method in agg.items():
-                 if col in target_estimates_per_cluster.columns and method == 'sum':
-                     if pd.api.types.is_numeric_dtype(target_estimates_per_cluster[col]):
-                         target_estimates_per_cluster[col] = target_reps[col].mul(policy_count_aligned, axis=0)
-         elif not agg:  # Default to sum if agg is None
-             for col in target_estimates_per_cluster.columns:
-                 if pd.api.types.is_numeric_dtype(target_estimates_per_cluster[col]):
-                     target_estimates_per_cluster[col] = target_reps[col].mul(policy_count_aligned, axis=0)
-
-         # Align source and target_estimates_per_cluster before stacking
-         # Both should have 'cluster_id' as index and data columns
-         aligned_source, aligned_target = source.align(target_estimates_per_cluster, join='inner', axis=0)  # Align rows (clusters)
-         aligned_source, aligned_target = aligned_source.align(aligned_target, join='inner', axis=1)  # Align columns
-
-         return pd.DataFrame({'actual': aligned_source.stack(), 'estimate': aligned_target.stack()})

      def compare_total(self, df, agg=None):
          if df.empty:
              return pd.DataFrame(columns=['actual', 'estimate', 'error'])

          op_for_actual = {}
          if isinstance(agg, dict):
              for c in df.columns:
-                 op_for_actual[c] = agg.get(c, 'sum')
-         else:
              for c in df.columns:
                  if pd.api.types.is_numeric_dtype(df[c]):
                      op_for_actual[c] = 'sum'
-
-         actual = df.agg(op_for_actual).dropna()

-         reps_values = self.extract_reps(df)
-         if reps_values.empty or self.policy_count.empty:
-             estimate = pd.Series(index=actual.index, dtype=float).fillna(np.nan)
          else:
              estimate_values = {}
-             policy_count_aligned = self.policy_count.reindex(reps_values.index).fillna(0)
-             total_weight = policy_count_aligned.sum()
-
-             for col_name in actual.index:
-                 col_op = op_for_actual.get(col_name)
-                 if col_name not in reps_values.columns or not pd.api.types.is_numeric_dtype(reps_values[col_name]):
                      estimate_values[col_name] = np.nan
                      continue
-
                  rep_col_values = reps_values[col_name]
                  if col_op == 'sum':
-                     estimate_values[col_name] = (rep_col_values * policy_count_aligned).sum()
                  elif col_op == 'mean':
-                     weighted_sum = (rep_col_values * policy_count_aligned).sum()
                      estimate_values[col_name] = weighted_sum / total_weight if total_weight != 0 else np.nan
-                 else:
                      estimate_values[col_name] = np.nan
-             estimate = pd.Series(estimate_values, index=actual.index)
-
          actual_aligned, estimate_aligned = actual.align(estimate, join='inner')
          error = pd.Series(index=actual_aligned.index, dtype=float)
          valid_mask = (actual_aligned != 0) & (~actual_aligned.isna())
          error[valid_mask] = estimate_aligned[valid_mask] / actual_aligned[valid_mask] - 1
          actual_zero_mask = (actual_aligned == 0) & (~actual_aligned.isna())
          error[actual_zero_mask & (estimate_aligned == 0)] = 0
-         error[actual_zero_mask & (estimate_aligned != 0) & (~estimate_aligned.isna())] = np.inf
          error = error.replace([np.inf, -np.inf], np.nan)

-         return pd.DataFrame({'actual': actual_aligned, 'estimate': estimate_aligned, 'error': error})
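# A minimal standalone sketch (toy numbers, not from app.py) of the error
# convention visible in compare_total above: error = estimate/actual - 1,
# with a zero actual mapping to 0 when the estimate is also zero and to NaN
# otherwise (via the inf -> NaN replacement).
import numpy as np
import pandas as pd

actual = pd.Series({'PV_NetCF': 100.0, 'PV_Premiums': 0.0})
estimate = pd.Series({'PV_NetCF': 98.0, 'PV_Premiums': 0.0})
error = pd.Series(index=actual.index, dtype=float)
nonzero = (actual != 0) & (~actual.isna())
error[nonzero] = estimate[nonzero] / actual[nonzero] - 1   # -0.02 for PV_NetCF
error[(actual == 0) & (estimate == 0)] = 0.0               # exact match on a zero total
print(error)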
 
  def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
      if not cfs_list or not cluster_obj or not titles or len(cfs_list) == 0:
-         fig, ax = plt.subplots(); ax.text(0.5, 0.5, "No data for cashflow plot.", ha='center', va='center')
          buf = io.BytesIO(); plt.savefig(buf, format='png'); buf.seek(0); img = Image.open(buf); plt.close(fig); return img

      num_plots = len(cfs_list)
-     cols = min(2, num_plots)  # Max 2 columns
      rows = (num_plots + cols - 1) // cols

-     fig, axes = plt.subplots(rows, cols, figsize=(7.5 * cols, 5 * rows), squeeze=False)
      axes = axes.flatten()
      plot_made = False
-
      for i, (df_cf, title) in enumerate(zip(cfs_list, titles)):
          if i < len(axes):
-             ax_curr = axes[i]
              if df_cf is None or df_cf.empty:
-                 ax_curr.text(0.5,0.5, f"No data for\n{title}", ha='center', va='center', wrap=True); ax_curr.set_title(title)
                  continue
-             try:
-                 comparison = cluster_obj.compare_total(df_cf)
-                 if not comparison.empty and 'actual' in comparison and 'estimate' in comparison:
-                     comparison[['actual', 'estimate']].plot(ax=ax_curr, grid=True, title=title)
-                     ax_curr.set_xlabel('Time Period')
-                     ax_curr.set_ylabel('Cashflow Value')
-                     plot_made = True
-                 else:
-                     ax_curr.text(0.5,0.5, f"Could not generate\ncomparison for {title}", ha='center', va='center', wrap=True); ax_curr.set_title(title)
-             except Exception as e:
-                 ax_curr.text(0.5,0.5, f"Error plotting {title}:\n{str(e)[:50]}...", ha='center', va='center', wrap=True); ax_curr.set_title(title)

-     for j in range(i + 1, len(axes)): fig.delaxes(axes[j])
-     if not plot_made:
-         plt.close(fig); fig, ax = plt.subplots(); ax.text(0.5, 0.5, "No cashflow plots generated.", ha='center', va='center')

-     plt.tight_layout(pad=2.0)
-     buf = io.BytesIO(); plt.savefig(buf, format='png', dpi=90); buf.seek(0); img = Image.open(buf); plt.close(fig); return img

  def plot_scatter_comparison(df_compare_output, title):
      if df_compare_output is None or df_compare_output.empty:
-         fig, ax = plt.subplots(figsize=(8,5)); ax.text(0.5, 0.5, "No data for scatter plot.", ha='center', va='center'); ax.set_title(title)
          buf = io.BytesIO(); plt.savefig(buf, format='png'); buf.seek(0); img = Image.open(buf); plt.close(fig); return img

-     fig, ax = plt.subplots(figsize=(8, 5))

      if not isinstance(df_compare_output.index, pd.MultiIndex) or df_compare_output.index.nlevels < 2:
          ax.scatter(df_compare_output.get('actual', []), df_compare_output.get('estimate', []), s=9, alpha=0.6)
      else:
          unique_levels = df_compare_output.index.get_level_values(1).unique()
-         if len(unique_levels) == 0:  # No data after all
-             ax.text(0.5, 0.5, "No data points for scatter.", ha='center', va='center')
-         else:
-             colors = matplotlib.cm.rainbow(np.linspace(0, 1, len(unique_levels)))
-             for item_level, color_val in zip(unique_levels, colors):
-                 subset = df_compare_output.xs(item_level, level=1)
-                 if not subset.empty:
-                     ax.scatter(subset['actual'], subset['estimate'], color=color_val, s=9, alpha=0.6, label=str(item_level))
-             if len(unique_levels) > 1 and len(unique_levels) <= 10:
-                 ax.legend(title=str(df_compare_output.index.names[1]), fontsize='small')
-
-     ax.set_xlabel('Actual Value')
-     ax.set_ylabel('Estimated Value')
-     ax.set_title(title, fontsize='medium')
-     ax.grid(True, linestyle='--', alpha=0.7)

      try:
-         current_xlim = ax.get_xlim(); current_ylim = ax.get_ylim()
-         if np.isfinite(current_xlim).all() and np.isfinite(current_ylim).all():  # Check if limits are valid
-             lims = [np.nanmin([current_xlim, current_ylim]), np.nanmax([current_xlim, current_ylim])]
-             if lims[0] != lims[1] and not np.isnan(lims[0]) and not np.isnan(lims[1]):
-                 ax.plot(lims, lims, 'r-', linewidth=1, alpha=0.8, dashes=(2,2))
-                 ax.set_xlim(lims); ax.set_ylim(lims)
-     except Exception: pass

-     plt.tight_layout(pad=1.5)
-     buf = io.BytesIO(); plt.savefig(buf, format='png', dpi=90); buf.seek(0); img = Image.open(buf); plt.close(fig); return img

  def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
                    policy_data_path, pv_base_path, pv_lapse_path, pv_mort_path):
@@ -285,13 +294,13 @@ def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
      cfs_mort15 = pd.read_excel(cashflow_mort_path, index_col=0)

      pol_data_full = pd.read_excel(policy_data_path, index_col=0)
-     required_policy_cols = ['age_at_entry', 'policy_term', 'sum_assured', 'duration_mth']
-     missing_policy_cols = [col for col in required_policy_cols if col not in pol_data_full.columns]
      if missing_policy_cols:
-         gr.Warning(f"Policy data missing: {', '.join(missing_policy_cols)}.")
-         pol_data = pol_data_full
      else:
-         pol_data = pol_data_full[required_policy_cols]

      pvs = pd.read_excel(pv_base_path, index_col=0)
      pvs_lapse50 = pd.read_excel(pv_lapse_path, index_col=0)
@@ -299,242 +308,296 @@ def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,

      cfs_list = [cfs, cfs_lapse50, cfs_mort15]
      scen_titles = ['Base', 'Lapse+50%', 'Mort+15%']
      mean_attrs_agg = {'age_at_entry':'mean', 'policy_term':'mean', 'duration_mth':'mean', 'sum_assured': 'sum'}

-     # --- Calibrations ---
-     gr.Info("Processing calibrations...")
-     cluster_cfs = Clusters(cfs) if not cfs.empty else None
-     if cluster_cfs:
-         results['cf_total_base_table'] = cluster_cfs.compare_total(cfs)
-         results['cf_policy_attrs_total'] = cluster_cfs.compare_total(pol_data, agg=mean_attrs_agg)
-         results['cf_pv_total_base'] = cluster_cfs.compare_total(pvs)
-         results['cf_pv_total_lapse'] = cluster_cfs.compare_total(pvs_lapse50)
-         results['cf_pv_total_mort'] = cluster_cfs.compare_total(pvs_mort15)
-         results['cf_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_cfs, scen_titles)
-         results['cf_scatter_cashflows_base'] = plot_scatter_comparison(cluster_cfs.compare(cfs), 'CF Calib. - Cashflows (Base)')
-     else: gr.Warning("Cashflow Calibration skipped due to empty base cashflow data.")
-
-     if not pol_data.empty:
-         pol_data_min = pol_data.min(); pol_data_range = pol_data.max() - pol_data_min
-         pol_data_range[pol_data_range == 0] = 1
-         loc_vars_attrs = ((pol_data - pol_data_min) / pol_data_range).fillna(0)
-         cluster_attrs = Clusters(loc_vars_attrs) if not loc_vars_attrs.empty else None
-     else: cluster_attrs = None; gr.Warning("Policy Attribute Calibration skipped due to empty policy data.")
-
-     if cluster_attrs:
          results['attr_total_cf_base'] = cluster_attrs.compare_total(cfs)
          results['attr_policy_attrs_total'] = cluster_attrs.compare_total(pol_data, agg=mean_attrs_agg)
          results['attr_total_pv_base'] = cluster_attrs.compare_total(pvs)
          results['attr_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_attrs, scen_titles)
          results['attr_scatter_cashflows_base'] = plot_scatter_comparison(cluster_attrs.compare(cfs), 'Attr Calib. - Cashflows (Base)')
-
-     cluster_pvs = Clusters(pvs) if not pvs.empty else None
-     if cluster_pvs:
-         results['pv_total_cf_base'] = cluster_pvs.compare_total(cfs)
-         results['pv_policy_attrs_total'] = cluster_pvs.compare_total(pol_data, agg=mean_attrs_agg)
-         results['pv_total_pv_base'] = cluster_pvs.compare_total(pvs)
-         results['pv_total_pv_lapse'] = cluster_pvs.compare_total(pvs_lapse50)
-         results['pv_total_pv_mort'] = cluster_pvs.compare_total(pvs_mort15)
-         results['pv_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_pvs, scen_titles)
-         results['pv_scatter_pvs_base'] = plot_scatter_comparison(cluster_pvs.compare(pvs), 'PV Calib. - PVs (Base)')
-     else: gr.Warning("PV Calibration skipped due to empty base PV data.")

-     # --- Summary Plot ---
      gr.Info("Generating Summary Plot...")
      error_data = {}
-     pv_col_name = 'PV_NetCF'
-     calibration_objects = [
-         ("CF Calib.", cluster_cfs),
-         ("Attr Calib.", cluster_attrs if 'cluster_attrs' in locals() else None),
-         ("PV Calib.", cluster_pvs)
-     ]

-     for calib_name_display, cluster_obj in calibration_objects:
          current_calib_errors = []
-         if cluster_obj is None:
              current_calib_errors = [np.nan, np.nan, np.nan]
          else:
              for pv_df_scenario in [pvs, pvs_lapse50, pvs_mort15]:
-                 if pv_df_scenario.empty: current_calib_errors.append(np.nan); continue
                  comp_total_df = cluster_obj.compare_total(pv_df_scenario)
-                 error_val = np.nan
-                 if not comp_total_df.empty:
-                     if pv_col_name in comp_total_df.index: error_val = comp_total_df.loc[pv_col_name, 'error']
-                     elif 'error' in comp_total_df.columns: error_val = comp_total_df['error'].mean()
                  current_calib_errors.append(abs(error_val))
          error_data[calib_name_display] = current_calib_errors
-
      summary_df = pd.DataFrame(error_data, index=['Base', 'Lapse+50%', 'Mort+15%'])
      fig_summary, ax_summary = plt.subplots(figsize=(10, 6))
-     plot_title = f'Calibration Method Comparison - Abs. Error in Total {pv_col_name}'

-     if summary_df.isnull().all().all() or summary_df.empty:
-         ax_summary.text(0.5, 0.5, f"Summary error data N/A.\nCheck PV files for '{pv_col_name}' & valid data.",
                          ha='center', va='center', transform=ax_summary.transAxes, wrap=True)
      else:
-         summary_df.plot(kind='bar', ax=ax_summary, grid=True, width=0.8)
          ax_summary.set_ylabel(f'Mean Absolute Error (of {pv_col_name} or fallback)')
          ax_summary.tick_params(axis='x', rotation=0)
-     ax_summary.set_title(plot_title)
-     plt.tight_layout(pad=1.5)
-     buf_summary = io.BytesIO(); plt.savefig(buf_summary, format='png', dpi=90); buf_summary.seek(0)
-     results['summary_plot'] = Image.open(buf_summary); plt.close(fig_summary)

-     # Round all DataFrame results to 2 decimal places
-     for key, value in results.items():
-         if isinstance(value, pd.DataFrame):
-             try:
-                 results[key] = value.round(2)
-             except (TypeError, AttributeError) as e:  # Non-numeric data in df
-                 gr.Debug(f"Could not round DataFrame for key '{key}': {e}")
-
-     gr.Info("All processing complete. ✅")
      return results

- except FileNotFoundError as e: gr.Error(f"File not found: {e.filename}."); return {"error": str(e)}
- except ValueError as e: gr.Error(f"Data error: {str(e)}"); return {"error": str(e)}
- except KeyError as e: gr.Error(f"Missing column: {e}. Check data formats."); return {"error": str(e)}
  except Exception as e:
-     gr.Error(f"Unexpected error: {str(e)}"); import traceback; traceback.print_exc()
-     return {"error": str(e)}

- def create_interface():
-     with gr.Blocks(title="Cluster Model Points Analysis", theme=gr.themes.Default()) as demo:  # Explicitly default theme
-         gr.Markdown("# Cluster Model Points Analysis wybrać")  # smaller heading
-         gr.Markdown(
-             "Applies k-means cluster analysis to select representative model points from an insurance portfolio. "
-             "Upload Excel files or use example data to analyze results using different calibration variables."
-         )
-         with gr.Accordion("📚 Instructions & File Requirements", open=False):
-             gr.Markdown(
-                 """
-                 **Required Excel (.xlsx) Files:**
-                 1. **Cashflows - Base Scenario**: Net annual cashflows (index: policy_id, columns: time periods).
-                 2. **Cashflows - Lapse Stress (+50%)**: Same format as Base.
-                 3. **Cashflows - Mortality Stress (+15%)**: Same format as Base.
-                 4. **Policy Data**: Attributes for each policy (index: policy_id). Must include columns: `age_at_entry`, `policy_term`, `sum_assured`, `duration_mth`.
-                 5. **Present Values - Base Scenario**: PVs of cashflow components (index: policy_id). Ideally include `PV_NetCF`.
-                 6. **Present Values - Lapse Stress**: Same format as Base PV.
-                 7. **Present Values - Mortality Stress**: Same format as Base PV.
-
-                 Ensure all files have a common `policy_id` that can be used as the index (set `index_col=0` when reading if policy_id is the first column).
-                 """
-             )

          with gr.Row():
-             with gr.Column(scale=3):  # Give more space to file inputs
                  gr.Markdown("### 📂 Upload Files or Load Examples")
                  with gr.Row():
-                     cashflow_base_input = gr.File(label="CF Base", file_types=[".xlsx"], scale=1)
-                     cashflow_lapse_input = gr.File(label="CF Lapse Str.", file_types=[".xlsx"], scale=1)
-                     cashflow_mort_input = gr.File(label="CF Mort Str.", file_types=[".xlsx"], scale=1)
                  with gr.Row():
-                     policy_data_input = gr.File(label="Policy Data", file_types=[".xlsx"], scale=1)
-                     pv_base_input = gr.File(label="PV Base", file_types=[".xlsx"], scale=1)
-                     pv_lapse_input = gr.File(label="PV Lapse Str.", file_types=[".xlsx"], scale=1)
                  with gr.Row():
-                     pv_mort_input = gr.File(label="PV Mort Str.", file_types=[".xlsx"], scale=1)
-             # Keep buttons in a separate row or column for better control
-             with gr.Column(scale=1, min_width=200):  # Column for buttons
-                 gr.Markdown("ㅤ")  # Spacer (Hangul filler character)
-                 load_example_btn = gr.Button("Load Example Data", icon="💾", elem_id="load-button")
-                 analyze_btn = gr.Button("Analyze Dataset", variant="primary", icon="🚀", elem_id="analyze-button")

          with gr.Tabs():
-             with gr.TabItem("📊 Summary", id="summary_tab"):
                  summary_plot_output = gr.Image(label="Calibration Methods Comparison")

-             tab_items_data = [
-                 ("💸 CF Calib.", "cf", "Annual Cashflows (Base)"),
-                 ("👤 Attr Calib.", "attr", "Policy Attributes"),
-                 ("💰 PV Calib.", "pv", "Present Values (Base)")
-             ]

-             for tab_name, prefix, calib_vars_desc in tab_items_data:
-                 with gr.TabItem(tab_name, id=f"{prefix}_calib_tab"):
-                     gr.Markdown(f"### Results: Using {calib_vars_desc} as Calibration Variables")
                      with gr.Row():
-                         globals()[f"{prefix}_total_base_table_out"] = gr.Dataframe(label="Overall Comparison - Base CF", wrap=True, height=250)
-                         globals()[f"{prefix}_policy_attrs_total_out"] = gr.Dataframe(label="Overall Comparison - Policy Attr.", wrap=True, height=250)
-
-                     globals()[f"{prefix}_cashflow_plot_out"] = gr.Image(label="Cashflow Value Comparisons")
-
-                     scatter_label = "Scatter: Per-Cluster PVs (Base)" if prefix == "pv" else "Scatter: Per-Cluster CFs (Base)"
-                     globals()[f"{prefix}_scatter_display_out"] = gr.Image(label=scatter_label)
-
-                     with gr.Accordion("Present Value Comparisons (Totals)", open=False):
-                         with gr.Row():
-                             globals()[f"{prefix}_pv_total_base_out"] = gr.Dataframe(label="PVs - Base", wrap=True, height=250)
-                             if prefix != "attr":  # Attr calib only shows base PV for brevity in original design
-                                 globals()[f"{prefix}_pv_total_lapse_out"] = gr.Dataframe(label="PVs - Lapse Stress", wrap=True, height=250)
-                                 globals()[f"{prefix}_pv_total_mort_out"] = gr.Dataframe(label="PVs - Mortality Stress", wrap=True, height=250)
-
-         # Define all output components dynamically based on tab_items_data
-         output_components = [summary_plot_output]
-         for _, prefix, _ in tab_items_data:
-             output_components.extend([
-                 globals()[f"{prefix}_total_base_table_out"], globals()[f"{prefix}_policy_attrs_total_out"],
-                 globals()[f"{prefix}_cashflow_plot_out"], globals()[f"{prefix}_scatter_display_out"],
-                 globals()[f"{prefix}_pv_total_base_out"]
-             ])
-             if prefix != "attr":
-                 output_components.extend([
-                     globals()[f"{prefix}_pv_total_lapse_out"], globals()[f"{prefix}_pv_total_mort_out"]
-                 ])
-
-         input_file_components = [
-             cashflow_base_input, cashflow_lapse_input, cashflow_mort_input,
-             policy_data_input, pv_base_input, pv_lapse_input, pv_mort_input
          ]
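# A sketch of an alternative to the globals()-based component registry above:
# collecting dynamically created components in a plain dict avoids mutating
# module globals. Names and labels here are illustrative, not the app's real ones.
import gradio as gr

tab_items_data = [("CF Calib.", "cf"), ("Attr Calib.", "attr"), ("PV Calib.", "pv")]
with gr.Blocks() as sketch:
    components = {}
    with gr.Tabs():
        for tab_name, prefix in tab_items_data:
            with gr.TabItem(tab_name):
                components[f"{prefix}_table_out"] = gr.Dataframe(label=f"{tab_name} table")
    output_components = list(components.values())  # ordered registry for .click outputs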
-         def handle_analysis_click(*files_input):  # Use *args
-             if not all(f is not None for f in files_input):
-                 gr.Warning("Not all files provided. Please upload/load all 7 files.")
-                 return [None] * len(output_components)

              file_paths = []
-             for f_obj in files_input:
-                 if hasattr(f_obj, 'name') and isinstance(f_obj.name, str): file_paths.append(f_obj.name)
-                 elif isinstance(f_obj, str): file_paths.append(f_obj)
-                 else: gr.Error(f"Invalid file input: {f_obj}."); return [None] * len(output_components)

              analysis_results = process_files(*file_paths)
-             if "error" in analysis_results: return [None] * len(output_components)

              # Map results to output components
-             output_values = [analysis_results.get('summary_plot')]
-             for _, prefix, _ in tab_items_data:
-                 output_values.extend([
-                     analysis_results.get(f'{prefix}_total_base_table'),
-                     analysis_results.get(f'{prefix}_policy_attrs_total'),
-                     analysis_results.get(f'{prefix}_cashflow_plot'),
-                     analysis_results.get(f'{prefix}_scatter_{"pvs" if prefix == "pv" else "cashflows"}_base'),  # Match key used in process_files
-                     analysis_results.get(f'{prefix}_pv_total_base')
-                 ])
-                 if prefix != "attr":
-                     output_values.extend([
-                         analysis_results.get(f'{prefix}_pv_total_lapse'),
-                         analysis_results.get(f'{prefix}_pv_total_mort')
-                     ])
-             return output_values
-
-         analyze_btn.click(handle_analysis_click, inputs=input_file_components, outputs=output_components)

          def load_example_files_action():
-             missing = [fp for fp in EXAMPLE_FILES.values() if not os.path.exists(fp)]
-             if missing: gr.Error(f"Missing example files: {', '.join(missing)}."); return [None] * 7
-             gr.Info(f"Example data loaded. Click 'Analyze Dataset'.")
-             return list(EXAMPLE_FILES.values())
-         load_example_btn.click(load_example_files_action, outputs=input_file_components)

      return demo

  if __name__ == "__main__":
      if not os.path.exists(EXAMPLE_DATA_DIR):
-         try: os.makedirs(EXAMPLE_DATA_DIR); print(f"Created '{EXAMPLE_DATA_DIR}'. Place example files there.")
-         except OSError as e: print(f"Error creating {EXAMPLE_DATA_DIR}: {e}. Please create manually.")

-     print("Starting Gradio application... Ensure example files are in './eg_data/'")
      demo_app = create_interface()
      demo_app.launch()
  import numpy as np
  import pandas as pd
  from sklearn.cluster import KMeans
+ from sklearn.metrics import pairwise_distances_argmin_min  # r2_score is not used in the final Gradio app logic
  import matplotlib.pyplot as plt
  import matplotlib.cm
  import io
+ import os  # Added for path joining
  from PIL import Image

  # Define the paths for example data

  class Clusters:
      def __init__(self, loc_vars):
+         # Ensure loc_vars is not empty before fitting KMeans
          if loc_vars.empty:
              raise ValueError("Input data for KMeans (loc_vars) is empty.")
          if loc_vars.isnull().all().all():
              raise ValueError("Input data for KMeans (loc_vars) contains all NaN values.")

+         self.kmeans = KMeans(n_clusters=min(1000, len(loc_vars)), random_state=0, n_init=10).fit(np.ascontiguousarray(loc_vars))
          closest, _ = pairwise_distances_argmin_min(self.kmeans.cluster_centers_, np.ascontiguousarray(loc_vars))

+         rep_ids = pd.Series(data=(closest + 1))  # 0-based to 1-based indexes
          rep_ids.name = 'policy_id'
          rep_ids.index.name = 'cluster_id'
          self.rep_ids = rep_ids

+         self.policy_count = self.agg_by_cluster(pd.DataFrame({'policy_count': [1] * len(loc_vars)}))['policy_count']
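# A standalone toy sketch of the representative-selection step above: fit
# KMeans, then take the actual sample nearest to each centroid as that
# cluster's model point. Data shapes and sizes here are illustrative only.
import numpy as np
import pandas as pd
from sklearn.cluster import KMeans
from sklearn.metrics import pairwise_distances_argmin_min

rng = np.random.default_rng(0)
X = pd.DataFrame(rng.normal(size=(100, 3)))  # 100 policies, 3 calibration variables
km = KMeans(n_clusters=5, random_state=0, n_init=10).fit(X)
closest, _ = pairwise_distances_argmin_min(km.cluster_centers_, X)
print(closest + 1)  # 1-based policy_id of each cluster's representative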
 
      def agg_by_cluster(self, df, agg=None):
          temp = df.copy()
+         temp['cluster_id'] = self.kmeans.labels_
          temp = temp.set_index('cluster_id')

+         # Ensure agg is a dictionary if not None
+         if agg is not None and not isinstance(agg, dict):
+             # Assuming if agg is not a dict, it's the default "sum" for all, which is handled by else.
+             # This case might need specific handling if agg can be other types.
+             # For now, if it's not a dict, treat as if no specific agg ops were given for columns.
+             agg_ops = {col: "sum" for col in temp.columns}  # Default to sum if agg format is unexpected
+         elif isinstance(agg, dict):
              agg_ops = {c: (agg[c] if c in agg else 'sum') for c in temp.columns}
+         else:  # agg is None
+             agg_ops = "sum"  # Pandas groupby will apply sum to all numeric columns

          return temp.groupby(temp.index).agg(agg_ops)
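# A toy illustration of the per-cluster aggregation above: group rows by
# cluster_id and apply a per-column op dict, with 'sum' as the default.
# Values are made up for the example.
import pandas as pd

df = pd.DataFrame({'cluster_id': [0, 0, 1],
                   'sum_assured': [100.0, 200.0, 50.0],
                   'age_at_entry': [30.0, 40.0, 55.0]}).set_index('cluster_id')
agg_ops = {'sum_assured': 'sum', 'age_at_entry': 'mean'}
print(df.groupby(df.index).agg(agg_ops))  # sums sum_assured, averages age_at_entry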
      def extract_reps(self, df):
+         temp = pd.merge(self.rep_ids, df.reset_index(), how='left', on='policy_id')
+         temp.index.name = 'cluster_id'
+         return temp.drop('policy_id', axis=1)

      def extract_and_scale_reps(self, df, agg=None):
          extracted_df = self.extract_reps(df)
          if extracted_df.empty:
+             return extracted_df  # Return empty if no representatives

          if agg and isinstance(agg, dict):
+             # mult should be a Series aligned with extracted_df's columns for element-wise multiplication after selection
+             # This part of the logic seems to intend to scale rows based on policy_count for 'sum' aggs
+             # and leave 'mean' aggs as is (to be weighted later).
+             # The original code created a DataFrame `mult` then did .mul(mult).
+             # A more direct approach for scaling rows:
+             scaled_df = extracted_df.copy()
              for c in extracted_df.columns:
+                 if agg.get(c, 'sum') == 'sum':  # Default to 'sum' if column not in agg
+                     scaled_df[c] = extracted_df[c].mul(self.policy_count, axis=0)
+                 # else (it's 'mean'), do not scale by policy_count here.
+             return scaled_df
+         else:  # Default: scale all columns by policy_count (as if for sum)
+             return extracted_df.mul(self.policy_count, axis=0)
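# A toy sketch of the scaling step above: multiplying each cluster's
# representative row by that cluster's policy count turns per-policy values
# into estimated cluster totals. Numbers are illustrative.
import pandas as pd

reps = pd.DataFrame({'net_cf': [10.0, -2.0]},
                    index=pd.Index([0, 1], name='cluster_id'))
policy_count = pd.Series([60, 40], index=reps.index)  # policies per cluster
print(reps.mul(policy_count, axis=0))  # estimated totals: 600.0 and -80.0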
 
      def compare(self, df, agg=None):
          source = self.agg_by_cluster(df, agg)
+         target = self.extract_and_scale_reps(df, agg)  # This target needs to be aggregated like source

+         # The target from extract_and_scale_reps is already scaled per cluster for 'sum' ops.
+         # For 'mean' ops, it's the representative value.
+         # We need to sum up the 'sum' columns and calculate weighted average for 'mean' columns.
+         if agg and isinstance(agg, dict):
+             agg_ops_for_target = {}
              for col, method in agg.items():
+                 if method == 'sum':
+                     agg_ops_for_target[col] = 'sum'
+                 elif method == 'mean':
+                     # For mean, we need sum(val*count)/sum(count).
+                     # extract_and_scale_reps DID NOT scale mean columns by policy_count.
+                     # So, target[col] has rep values. We need to weight them.
+                     # This is better handled in compare_total. Here, target is per-cluster.
+                     # This function compares per-cluster values BEFORE final aggregation.
+                     # So target should represent aggregated values per cluster.
+                     pass  # 'sum' columns are scaled, 'mean' columns are rep values
+         else:  # all sum
+             pass  # target is already scaled by policy_count, so it's the sum per cluster
+
+         # This function is for per-cluster comparison, not total.
+         # The 'target' from extract_and_scale_reps already has the representative values scaled by policy_count for sum-like aggregations.
+         # If a column is meant for 'mean', it's just the representative value.
+         # This 'compare' function might be misinterpreting 'target' if 'agg' has 'mean'.
+         # The original notebook's compare function:
+         #   source = self.agg_by_cluster(df, agg)         # Actual sums/means per cluster
+         #   target = self.extract_and_scale_reps(df, agg) # Rep values, scaled by count if 'sum', unscaled if 'mean'
+         # This structure implies 'target' might not be directly comparable if 'mean' is involved without further processing.
+         # However, the scatter plots it generates plot these per-cluster values.
+         # For 'sum' variables, target is an estimate of the cluster total.
+         # For 'mean' variables, target is the rep's value (estimate of cluster mean).
+
+         return pd.DataFrame({'actual': source.stack(), 'estimate': target.stack()})

      def compare_total(self, df, agg=None):
+         """Aggregate df by columns and compare actual vs estimate totals."""
          if df.empty:
              return pd.DataFrame(columns=['actual', 'estimate', 'error'])

+         # Determine aggregation operations for each column
          op_for_actual = {}
          if isinstance(agg, dict):
              for c in df.columns:
+                 op_for_actual[c] = agg.get(c, 'sum')  # Default to 'sum' if not in agg
+         else:  # agg is None or not a dict, apply sum to all
              for c in df.columns:
                  if pd.api.types.is_numeric_dtype(df[c]):
                      op_for_actual[c] = 'sum'
+                 # else: non-numeric columns will be ignored by df.agg if op not specified
+
+         actual = df.agg(op_for_actual)
+         actual = actual.dropna()  # Remove non-numeric results if any

+         # Calculate estimate
+         reps_values = self.extract_reps(df)  # Get raw representative values (one per cluster)
+         if reps_values.empty:  # No representatives found
+             estimate = pd.Series(index=actual.index, dtype=float)  # Empty or NaN series
          else:
              estimate_values = {}
+             for col_name in actual.index:  # Iterate over columns that had a valid actual aggregation
+                 col_op = op_for_actual.get(col_name, 'sum')
+
+                 if col_name not in reps_values.columns:  # Should not happen if df columns match
                      estimate_values[col_name] = np.nan
                      continue
+
                  rep_col_values = reps_values[col_name]
+
                  if col_op == 'sum':
+                     # Estimate for sum is sum of (representative_value * policy_count_for_its_cluster)
+                     estimate_values[col_name] = (rep_col_values * self.policy_count).sum()
                  elif col_op == 'mean':
+                     # Estimate for mean is weighted average: sum(rep_value * policy_count) / sum(policy_count)
+                     weighted_sum = (rep_col_values * self.policy_count).sum()
+                     total_weight = self.policy_count.sum()
                      estimate_values[col_name] = weighted_sum / total_weight if total_weight != 0 else np.nan
+                 else:  # Should not happen given op_for_actual logic
                      estimate_values[col_name] = np.nan
+
+             estimate = pd.Series(estimate_values, index=actual.index)  # Align with actual's index
+
+         # Calculate error
+         # Align actual and estimate to ensure they cover the same items for error calculation
          actual_aligned, estimate_aligned = actual.align(estimate, join='inner')
+
          error = pd.Series(index=actual_aligned.index, dtype=float)
+
+         # Valid division where actual is not zero and not NaN
          valid_mask = (actual_aligned != 0) & (~actual_aligned.isna())
          error[valid_mask] = estimate_aligned[valid_mask] / actual_aligned[valid_mask] - 1
+
+         # Where actual is zero (and not NaN)
          actual_zero_mask = (actual_aligned == 0) & (~actual_aligned.isna())
+         # If estimate is also zero, error is 0
          error[actual_zero_mask & (estimate_aligned == 0)] = 0
+         # If estimate is non-zero and actual is zero, error is effectively infinite
+         error[actual_zero_mask & (estimate_aligned != 0)] = np.inf
+
+         # Replace any infinities with NaN for cleaner results (e.g., for .mean())
          error = error.replace([np.inf, -np.inf], np.nan)

+         result_df = pd.DataFrame({'actual': actual_aligned, 'estimate': estimate_aligned, 'error': error})
+         return result_df
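# A toy sketch of the two estimate formulas in compare_total above: a 'sum'
# column totals rep_value * policy_count, while a 'mean' column takes the
# policy-count-weighted average of the representative values. Numbers are made up.
import pandas as pd

rep_vals = pd.Series([30.0, 50.0])   # representative age_at_entry per cluster
counts = pd.Series([60, 40])         # policies per cluster
est_sum = (rep_vals * counts).sum()                   # 'sum' estimate: 3800.0
est_mean = (rep_vals * counts).sum() / counts.sum()   # 'mean' estimate: 38.0
print(est_sum, est_mean)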
 
  def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
      if not cfs_list or not cluster_obj or not titles or len(cfs_list) == 0:
+         fig, ax = plt.subplots()
+         ax.text(0.5, 0.5, "No data for cashflow comparison plot.", ha='center', va='center')
          buf = io.BytesIO(); plt.savefig(buf, format='png'); buf.seek(0); img = Image.open(buf); plt.close(fig); return img

      num_plots = len(cfs_list)
+     cols = 2
      rows = (num_plots + cols - 1) // cols

+     fig, axes = plt.subplots(rows, cols, figsize=(15, 5 * rows), squeeze=False)
      axes = axes.flatten()
+
      plot_made = False
      for i, (df_cf, title) in enumerate(zip(cfs_list, titles)):
          if i < len(axes):
              if df_cf is None or df_cf.empty:
+                 axes[i].text(0.5,0.5, f"No data for {title}", ha='center', va='center')
+                 axes[i].set_title(title)
                  continue
+             comparison = cluster_obj.compare_total(df_cf)  # Default is sum for all columns
+             if not comparison.empty and 'actual' in comparison and 'estimate' in comparison:
+                 comparison[['actual', 'estimate']].plot(ax=axes[i], grid=True, title=title)
+                 axes[i].set_xlabel('Time')
+                 axes[i].set_ylabel('Value')
+                 plot_made = True
+             else:
+                 axes[i].text(0.5,0.5, f"Could not generate comparison for {title}", ha='center', va='center')
+                 axes[i].set_title(title)

+     for j in range(i + 1, len(axes)):  # Hide unused subplots
+         fig.delaxes(axes[j])
+
+     if not plot_made:  # If no plots were actually made (e.g. all data was empty)
+         plt.close(fig)  # Close the figure
+         fig, ax = plt.subplots()  # Create a new one for the message
+         ax.text(0.5, 0.5, "Insufficient data for any cashflow plots.", ha='center', va='center')

+     plt.tight_layout()
+     buf = io.BytesIO()
+     plt.savefig(buf, format='png', dpi=100)
+     buf.seek(0)
+     img = Image.open(buf)
+     plt.close(fig)
+     return img

  def plot_scatter_comparison(df_compare_output, title):
      if df_compare_output is None or df_compare_output.empty:
+         fig, ax = plt.subplots(figsize=(10,6)); ax.text(0.5, 0.5, "No data for scatter plot.", ha='center', va='center'); ax.set_title(title)
          buf = io.BytesIO(); plt.savefig(buf, format='png'); buf.seek(0); img = Image.open(buf); plt.close(fig); return img

+     fig, ax = plt.subplots(figsize=(10, 6))

      if not isinstance(df_compare_output.index, pd.MultiIndex) or df_compare_output.index.nlevels < 2:
+         # This case indicates df_compare_output is not from cluster_obj.compare() as expected
          ax.scatter(df_compare_output.get('actual', []), df_compare_output.get('estimate', []), s=9, alpha=0.6)
      else:
          unique_levels = df_compare_output.index.get_level_values(1).unique()
+         colors = matplotlib.cm.rainbow(np.linspace(0, 1, len(unique_levels)))
+
+         for item_level, color_val in zip(unique_levels, colors):
+             subset = df_compare_output.xs(item_level, level=1)
+             ax.scatter(subset['actual'], subset['estimate'], color=color_val, s=9, alpha=0.6, label=str(item_level))  # Ensure label is string
+         if len(unique_levels) > 1 and len(unique_levels) <= 10:
+             ax.legend(title=df_compare_output.index.names[1])
+
+     ax.set_xlabel('Actual')
+     ax.set_ylabel('Estimate')
+     ax.set_title(title)
+     ax.grid(True)

      try:
+         current_xlim = ax.get_xlim()
+         current_ylim = ax.get_ylim()
+         lims = [
+             np.nanmin([current_xlim, current_ylim]),
+             np.nanmax([current_xlim, current_ylim]),
+         ]
+         if lims[0] != lims[1] and not np.isnan(lims[0]) and not np.isnan(lims[1]):
+             ax.plot(lims, lims, 'r-', linewidth=0.5)
+             ax.set_xlim(lims)
+             ax.set_ylim(lims)
+     except Exception:  # Catch errors if lims are problematic (e.g. all NaNs)
+         pass

+     buf = io.BytesIO()
+     plt.savefig(buf, format='png', dpi=100)
+     buf.seek(0)
+     img = Image.open(buf)
+     plt.close(fig)
+     return img

  def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
                    policy_data_path, pv_base_path, pv_lapse_path, pv_mort_path):
      cfs_mort15 = pd.read_excel(cashflow_mort_path, index_col=0)

      pol_data_full = pd.read_excel(policy_data_path, index_col=0)
+     required_cols = ['age_at_entry', 'policy_term', 'sum_assured', 'duration_mth']
+     missing_policy_cols = [col for col in required_cols if col not in pol_data_full.columns]
      if missing_policy_cols:
+         gr.Warning(f"Policy data is missing required columns: {', '.join(missing_policy_cols)}. Analysis may be affected.")
+         pol_data = pol_data_full  # Use what's available
      else:
+         pol_data = pol_data_full[required_cols]

      pvs = pd.read_excel(pv_base_path, index_col=0)
      pvs_lapse50 = pd.read_excel(pv_lapse_path, index_col=0)

      cfs_list = [cfs, cfs_lapse50, cfs_mort15]
      scen_titles = ['Base', 'Lapse+50%', 'Mort+15%']
+
      mean_attrs_agg = {'age_at_entry':'mean', 'policy_term':'mean', 'duration_mth':'mean', 'sum_assured': 'sum'}

+     # --- 1. Cashflow Calibration ---
+     gr.Info("Starting Cashflow Calibration...")
+     if cfs.empty: gr.Warning("Base cashflow data is empty for Cashflow Calibration.")
+     cluster_cfs = Clusters(cfs)
+     results['cf_total_base_table'] = cluster_cfs.compare_total(cfs)
+     results['cf_policy_attrs_total'] = cluster_cfs.compare_total(pol_data, agg=mean_attrs_agg)
+     results['cf_pv_total_base'] = cluster_cfs.compare_total(pvs)
+     results['cf_pv_total_lapse'] = cluster_cfs.compare_total(pvs_lapse50)
+     results['cf_pv_total_mort'] = cluster_cfs.compare_total(pvs_mort15)
+     results['cf_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_cfs, scen_titles)
+     results['cf_scatter_cashflows_base'] = plot_scatter_comparison(cluster_cfs.compare(cfs), 'CF Calib. - Cashflows (Base)')
+     gr.Info("Cashflow Calibration Done.")
+
+     # --- 2. Policy Attribute Calibration ---
+     gr.Info("Starting Policy Attribute Calibration...")
+     if pol_data.empty:
+         gr.Warning("Policy data is empty. Skipping Policy Attribute Calibration.")
+         loc_vars_attrs = pd.DataFrame()  # Empty dataframe
+     else:
+         pol_data_min = pol_data.min()
+         pol_data_range = pol_data.max() - pol_data_min
+         # Avoid division by zero if a column has no variance (all values are the same)
+         if (pol_data_range == 0).any():
+             gr.Warning("Some policy attributes have no variance (all values are the same). Standardization might be affected.")
+             # For columns with zero range, standardized value becomes 0 or NaN depending on pandas version.
+             # A common approach is to set them to 0 or handle them separately.
+             # Here, we proceed, but pandas might produce NaNs if (val - min) / 0 occurs.
+             # Let's ensure range is not zero for division:
+             pol_data_range[pol_data_range == 0] = 1  # Avoid division by zero, effectively making constant columns 0 after (x-min)/1
+         loc_vars_attrs = (pol_data - pol_data_min) / pol_data_range
+         loc_vars_attrs = loc_vars_attrs.fillna(0)  # Handle any NaNs from perfect constant columns
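# A toy sketch of the min-max standardization above, including the
# zero-variance guard: a constant column would otherwise divide by zero.
# Values are illustrative.
import pandas as pd

pol = pd.DataFrame({'age_at_entry': [20, 40, 60], 'policy_term': [10, 10, 10]})
col_range = pol.max() - pol.min()
col_range[col_range == 0] = 1            # guard: policy_term has no variance
scaled = (pol - pol.min()) / col_range   # age -> 0.0, 0.5, 1.0; policy_term -> all 0.0
print(scaled)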
+     if not loc_vars_attrs.empty:
+         cluster_attrs = Clusters(loc_vars_attrs)
          results['attr_total_cf_base'] = cluster_attrs.compare_total(cfs)
          results['attr_policy_attrs_total'] = cluster_attrs.compare_total(pol_data, agg=mean_attrs_agg)
          results['attr_total_pv_base'] = cluster_attrs.compare_total(pvs)
          results['attr_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_attrs, scen_titles)
          results['attr_scatter_cashflows_base'] = plot_scatter_comparison(cluster_attrs.compare(cfs), 'Attr Calib. - Cashflows (Base)')
+     else:
+         results.update({k: pd.DataFrame() for k in ['attr_total_cf_base', 'attr_policy_attrs_total', 'attr_total_pv_base']})
+         results.update({k: None for k in ['attr_cashflow_plot', 'attr_scatter_cashflows_base']})
+     gr.Info("Policy Attribute Calibration Done.")
+
+     # --- 3. Present Value Calibration ---
+     gr.Info("Starting Present Value Calibration...")
+     if pvs.empty: gr.Warning("Base Present Value data is empty for PV Calibration.")
+     cluster_pvs = Clusters(pvs)
+     results['pv_total_cf_base'] = cluster_pvs.compare_total(cfs)
+     results['pv_policy_attrs_total'] = cluster_pvs.compare_total(pol_data, agg=mean_attrs_agg)
+     results['pv_total_pv_base'] = cluster_pvs.compare_total(pvs)
+     results['pv_total_pv_lapse'] = cluster_pvs.compare_total(pvs_lapse50)
+     results['pv_total_pv_mort'] = cluster_pvs.compare_total(pvs_mort15)
+     results['pv_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_pvs, scen_titles)
+     results['pv_scatter_pvs_base'] = plot_scatter_comparison(cluster_pvs.compare(pvs), 'PV Calib. - PVs (Base)')
+     gr.Info("Present Value Calibration Done.")

+     # --- Summary Comparison Plot Data ---
      gr.Info("Generating Summary Plot...")
      error_data = {}
+     pv_col_name = 'PV_NetCF'  # Target column for summary

+     for calib_prefix, cluster_obj, calib_name_display in [
+             ('CF Calib.', cluster_cfs, "CF Calib."),
+             ('Attr Calib.', globals().get('cluster_attrs'), "Attr Calib."),
+             ('PV Calib.', cluster_pvs, "PV Calib.")]:
+
          current_calib_errors = []
+         if cluster_obj is None and calib_prefix == 'Attr Calib.':  # Attr calib might be skipped
              current_calib_errors = [np.nan, np.nan, np.nan]
          else:
              for pv_df_scenario in [pvs, pvs_lapse50, pvs_mort15]:
+                 if pv_df_scenario.empty:
+                     current_calib_errors.append(np.nan)
+                     continue
+
                  comp_total_df = cluster_obj.compare_total(pv_df_scenario)
+                 if pv_col_name in comp_total_df.index:
+                     error_val = comp_total_df.loc[pv_col_name, 'error']
+                 elif not comp_total_df.empty and 'error' in comp_total_df.columns:
+                     error_val = comp_total_df['error'].mean()  # Fallback
+                     if calib_prefix == 'CF Calib.' and pv_df_scenario is pvs:  # Only warn once per type if fallback
+                         gr.Warning(f"'{pv_col_name}' not found for summary plot. Using mean error of all PV columns instead for {calib_name_display}.")
+                 else:  # comp_total_df is empty or no 'error' column
+                     error_val = np.nan
                  current_calib_errors.append(abs(error_val))
          error_data[calib_name_display] = current_calib_errors
+
      summary_df = pd.DataFrame(error_data, index=['Base', 'Lapse+50%', 'Mort+15%'])
+
      fig_summary, ax_summary = plt.subplots(figsize=(10, 6))

+     plot_title = f'Calibration Method Comparison - Abs. Error in Total {pv_col_name}'
+     if summary_df.isnull().all().all():
+         ax_summary.text(0.5, 0.5, f"Error data for summary is N/A.\nCheck input PV files for '{pv_col_name}' column and valid numeric data.",
                          ha='center', va='center', transform=ax_summary.transAxes, wrap=True)
+         ax_summary.set_title(plot_title)
+     elif summary_df.empty:
+         ax_summary.text(0.5, 0.5, "No summary data to plot.", ha='center', va='center')
+         ax_summary.set_title(plot_title)
      else:
+         summary_df.plot(kind='bar', ax=ax_summary, grid=True)
          ax_summary.set_ylabel(f'Mean Absolute Error (of {pv_col_name} or fallback)')
+         ax_summary.set_title(plot_title)
          ax_summary.tick_params(axis='x', rotation=0)

+     plt.tight_layout()
+     buf_summary = io.BytesIO(); plt.savefig(buf_summary, format='png', dpi=100); buf_summary.seek(0)
+     results['summary_plot'] = Image.open(buf_summary)
+     plt.close(fig_summary)
+     gr.Info("All processing complete.")
      return results

+ except FileNotFoundError as e:
+     gr.Error(f"File not found: {e.filename}. Ensure example files are in '{EXAMPLE_DATA_DIR}' or all files are uploaded correctly.")
+     return {"error": f"File not found: {e.filename}"}
+ except ValueError as e:  # Catch specific errors like empty data for KMeans
+     gr.Error(f"Data validation error: {str(e)}")
+     return {"error": f"Data error: {str(e)}"}
+ except KeyError as e:
+     gr.Error(f"A required column is missing: {e}. Please check data formats, especially index columns and expected data columns like 'PV_NetCF'.")
+     return {"error": f"Missing column: {e}"}
  except Exception as e:
+     gr.Error(f"An unexpected error occurred during processing: {str(e)}")
+     import traceback
+     traceback.print_exc()  # Print full traceback to console for debugging
+     return {"error": f"Processing error: {str(e)}"}

+ def create_interface():
+     with gr.Blocks(title="Cluster Model Points Analysis") as demo:
+         gr.Markdown("""
+         # Cluster Model Points Analysis
+         This application applies k-means cluster analysis to select representative model points from an insurance portfolio.
+         Upload your Excel files or use the example data to analyze results based on different calibration variable choices.
+         **Required Excel (.xlsx) Files:**
+         - Cashflows - Base Scenario
+         - Cashflows - Lapse Stress (+50%)
+         - Cashflows - Mortality Stress (+15%)
+         - Policy Data (must include 'age_at_entry', 'policy_term', 'sum_assured', 'duration_mth', and an index column for `policy_id`)
+         - Present Values - Base Scenario (ideally with a 'PV_NetCF' column and an index column for `policy_id`)
+         - Present Values - Lapse Stress (same structure as Base PV)
+         - Present Values - Mortality Stress (same structure as Base PV)
+         """)
+
          with gr.Row():
+             with gr.Column(scale=1):
                  gr.Markdown("### 📂 Upload Files or Load Examples")
+                 load_example_btn = gr.Button("Load Example Data", icon="💾")
                  with gr.Row():
+                     cashflow_base_input = gr.File(label="Cashflows - Base", file_types=[".xlsx"])
+                     cashflow_lapse_input = gr.File(label="Cashflows - Lapse Stress", file_types=[".xlsx"])
+                     cashflow_mort_input = gr.File(label="Cashflows - Mortality Stress", file_types=[".xlsx"])
                  with gr.Row():
+                     policy_data_input = gr.File(label="Policy Data", file_types=[".xlsx"])
+                     pv_base_input = gr.File(label="Present Values - Base", file_types=[".xlsx"])
+                     pv_lapse_input = gr.File(label="Present Values - Lapse Stress", file_types=[".xlsx"])
                  with gr.Row():
+                     pv_mort_input = gr.File(label="Present Values - Mortality Stress", file_types=[".xlsx"])
+                     span_dummy = gr.File(visible=False)  # For layout balance if needed
+                     span_dummy2 = gr.File(visible=False)
+
+         analyze_btn = gr.Button("Analyze Dataset", variant="primary", icon="🚀", scale=1)

          with gr.Tabs():
+             with gr.TabItem("📊 Summary"):
                  summary_plot_output = gr.Image(label="Calibration Methods Comparison")

+             with gr.TabItem("💸 Cashflow Calibration"):
+                 gr.Markdown("### Results: Using Annual Cashflows (Base) as Calibration Variables")
+                 with gr.Row():
+                     cf_total_base_table_out = gr.Dataframe(label="Overall Comparison - Base CF", wrap=True)
+                     cf_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes", wrap=True)
+                 cf_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate)")
+                 cf_scatter_cashflows_base_out = gr.Image(label="Scatter: Per-Cluster Cashflows (Base)")
+                 with gr.Accordion("Present Value Comparisons (Totals)", open=False):
+                     with gr.Row():
+                         cf_pv_total_base_out = gr.Dataframe(label="PVs - Base", wrap=True)
+                         cf_pv_total_lapse_out = gr.Dataframe(label="PVs - Lapse Stress", wrap=True)
+                         cf_pv_total_mort_out = gr.Dataframe(label="PVs - Mortality Stress", wrap=True)

+             with gr.TabItem("👤 Policy Attribute Calibration"):
+                 gr.Markdown("### Results: Using Policy Attributes as Calibration Variables")
+                 with gr.Row():
+                     attr_total_cf_base_out = gr.Dataframe(label="Overall Comparison - Base CF", wrap=True)
+                     attr_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes", wrap=True)
+                 attr_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate)")
+                 attr_scatter_cashflows_base_out = gr.Image(label="Scatter: Per-Cluster Cashflows (Base)")
+                 with gr.Accordion("Present Value Comparisons (Totals)", open=False):
+                     attr_total_pv_base_out = gr.Dataframe(label="PVs - Base Scenario", wrap=True)
+
+             with gr.TabItem("💰 Present Value Calibration"):
+                 gr.Markdown("### Results: Using Present Values (Base) as Calibration Variables")
+                 with gr.Row():
+                     pv_total_cf_base_out = gr.Dataframe(label="Overall Comparison - Base CF", wrap=True)
+                     pv_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes", wrap=True)
+                 pv_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate)")
+                 pv_scatter_pvs_base_out = gr.Image(label="Scatter: Per-Cluster PVs (Base)")
+                 with gr.Accordion("Present Value Comparisons (Totals)", open=False):
                      with gr.Row():
+                         pv_total_pv_base_out = gr.Dataframe(label="PVs - Base", wrap=True)
+                         pv_total_pv_lapse_out = gr.Dataframe(label="PVs - Lapse Stress", wrap=True)
+                         pv_total_pv_mort_out = gr.Dataframe(label="PVs - Mortality Stress", wrap=True)
+
+         output_components = [
+             summary_plot_output,
+             cf_total_base_table_out, cf_policy_attrs_total_out, cf_cashflow_plot_out, cf_scatter_cashflows_base_out,
+             cf_pv_total_base_out, cf_pv_total_lapse_out, cf_pv_total_mort_out,
+             attr_total_cf_base_out, attr_policy_attrs_total_out, attr_cashflow_plot_out, attr_scatter_cashflows_base_out, attr_total_pv_base_out,
+             pv_total_cf_base_out, pv_policy_attrs_total_out, pv_cashflow_plot_out, pv_scatter_pvs_base_out,
+             pv_total_pv_base_out, pv_total_pv_lapse_out, pv_total_pv_mort_out
+         ]
+
+         def handle_analysis_click(f1, f2, f3, f4, f5, f6, f7):
+             all_files_present = all(f is not None for f in [f1, f2, f3, f4, f5, f6, f7])
+             if not all_files_present:
+                 gr.Warning("Not all files have been provided. Please upload all 7 files or load example data.")
+                 return [None] * len(output_components)  # Return Nones for all output components

+             # file objects (f1, etc.) from gr.File are TemporaryFileWrapper or string paths if loaded by example
              file_paths = []
+             for f_obj in [f1, f2, f3, f4, f5, f6, f7]:
+                 if hasattr(f_obj, 'name') and isinstance(f_obj.name, str):  # Uploaded file
+                     file_paths.append(f_obj.name)
+                 elif isinstance(f_obj, str):  # Path from example load
+                     file_paths.append(f_obj)
+                 else:  # Should not happen if files are present
+                     gr.Error(f"Invalid file input: {f_obj}. Please re-upload or reload examples.")
+                     return [None] * len(output_components)

              analysis_results = process_files(*file_paths)
+
+             if "error" in analysis_results:  # Error handled and displayed by process_files
+                 return [None] * len(output_components)

              # Map results to output components
+             return [
+                 analysis_results.get('summary_plot'),
+                 analysis_results.get('cf_total_base_table'), analysis_results.get('cf_policy_attrs_total'),
+                 analysis_results.get('cf_cashflow_plot'), analysis_results.get('cf_scatter_cashflows_base'),
+                 analysis_results.get('cf_pv_total_base'), analysis_results.get('cf_pv_total_lapse'), analysis_results.get('cf_pv_total_mort'),
+                 analysis_results.get('attr_total_cf_base'), analysis_results.get('attr_policy_attrs_total'),
+                 analysis_results.get('attr_cashflow_plot'), analysis_results.get('attr_scatter_cashflows_base'), analysis_results.get('attr_total_pv_base'),
+                 analysis_results.get('pv_total_cf_base'), analysis_results.get('pv_policy_attrs_total'),
+                 analysis_results.get('pv_cashflow_plot'), analysis_results.get('pv_scatter_pvs_base'),
+                 analysis_results.get('pv_total_pv_base'), analysis_results.get('pv_total_pv_lapse'), analysis_results.get('pv_total_pv_mort')
+             ]
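# A standalone sketch of the input normalization inside handle_analysis_click
# above: gr.File values arrive either as objects carrying a .name path
# (uploads) or as plain path strings (programmatic example loading).
def to_path(f_obj):
    if hasattr(f_obj, 'name') and isinstance(f_obj.name, str):
        return f_obj.name            # uploaded temp-file wrapper
    if isinstance(f_obj, str):
        return f_obj                 # path set programmatically
    raise ValueError(f"Unsupported file input: {f_obj!r}")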
+         analyze_btn.click(
+             handle_analysis_click,
+             inputs=[cashflow_base_input, cashflow_lapse_input, cashflow_mort_input,
+                     policy_data_input, pv_base_input, pv_lapse_input, pv_mort_input],
+             outputs=output_components
+         )

+         input_file_components = [
+             cashflow_base_input, cashflow_lapse_input, cashflow_mort_input,
+             policy_data_input, pv_base_input, pv_lapse_input, pv_mort_input
+         ]
          def load_example_files_action():
+             missing_example_files = [fp for fp in EXAMPLE_FILES.values() if not os.path.exists(fp)]
+             if missing_example_files:
+                 gr.Error(f"Missing example data files in '{EXAMPLE_DATA_DIR}': {', '.join(missing_example_files)}. Please ensure they exist.")
+                 return [None] * len(input_file_components)
+             gr.Info(f"Example data paths loaded from '{EXAMPLE_DATA_DIR}'. Click 'Analyze Dataset'.")
+             return [
+                 EXAMPLE_FILES["cashflow_base"], EXAMPLE_FILES["cashflow_lapse"], EXAMPLE_FILES["cashflow_mort"],
+                 EXAMPLE_FILES["policy_data"], EXAMPLE_FILES["pv_base"], EXAMPLE_FILES["pv_lapse"],
+                 EXAMPLE_FILES["pv_mort"]
+             ]
+         load_example_btn.click(load_example_files_action, inputs=[], outputs=input_file_components)
      return demo

  if __name__ == "__main__":
      if not os.path.exists(EXAMPLE_DATA_DIR):
+         try:
+             os.makedirs(EXAMPLE_DATA_DIR)
+             print(f"Created directory '{EXAMPLE_DATA_DIR}'. Please place example Excel files there.")
+             print(f"Expected files: {list(EXAMPLE_FILES.keys())}")
+         except OSError as e:
+             print(f"Error creating directory {EXAMPLE_DATA_DIR}: {e}. Please create it manually.")

+     print("Starting Gradio application...")
+     print(f"Note: Ensure your example Excel files are placed in the '{os.getcwd()}{os.sep}{EXAMPLE_DATA_DIR}' folder.")
+     print(f"Required policy data columns: 'age_at_entry', 'policy_term', 'sum_assured', 'duration_mth' (and an index col).")
+     print(f"Recommended PV files column for summary: 'PV_NetCF' (and an index col).")
+
      demo_app = create_interface()
      demo_app.launch()