alidenewade committed on
Commit e82ad24 · verified · 1 Parent(s): 46a2e7c

Update app.py

Files changed (1):
  1. app.py +545 -583
app.py CHANGED
@@ -2,608 +2,570 @@ import gradio as gr
  import numpy as np
  import pandas as pd
  from sklearn.cluster import KMeans
- from sklearn.metrics import pairwise_distances_argmin_min
  import matplotlib.pyplot as plt
  import matplotlib.cm
  import io
- import os
  from PIL import Image
 
  # Define the paths for example data
  EXAMPLE_DATA_DIR = "eg_data"
  EXAMPLE_FILES = {
-     "cashflow_base": os.path.join(EXAMPLE_DATA_DIR, "cashflows_seriatim_10K.xlsx"),
-     "cashflow_lapse": os.path.join(EXAMPLE_DATA_DIR, "cashflows_seriatim_10K_lapse50.xlsx"),
-     "cashflow_mort": os.path.join(EXAMPLE_DATA_DIR, "cashflows_seriatim_10K_mort15.xlsx"),
-     "policy_data": os.path.join(EXAMPLE_DATA_DIR, "model_point_table.xlsx"),
-     "pv_base": os.path.join(EXAMPLE_DATA_DIR, "pv_seriatim_10K.xlsx"),
-     "pv_lapse": os.path.join(EXAMPLE_DATA_DIR, "pv_seriatim_10K_lapse50.xlsx"),
-     "pv_mort": os.path.join(EXAMPLE_DATA_DIR, "pv_seriatim_10K_mort15.xlsx"),
  }
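For running the app without the real eg_data workbooks, stand-in files matching these paths can be generated. A minimal sketch with purely synthetic values (writing .xlsx requires openpyxl; the variable names and all numbers here are made up for illustration):

import os
import numpy as np
import pandas as pd

os.makedirs(EXAMPLE_DATA_DIR, exist_ok=True)
rng = np.random.default_rng(0)
n = 2000  # the updated Clusters hard-codes n_clusters=1000, so supply at least 1000 rows

# Policy attributes with the columns the app requires, indexed by policy_id starting at 1.
policies = pd.DataFrame({
    'age_at_entry': rng.integers(20, 60, n),
    'policy_term': rng.choice([10, 15, 20], n),
    'sum_assured': rng.integers(10_000, 1_000_000, n),
    'duration_mth': rng.integers(0, 120, n),
}, index=pd.RangeIndex(1, n + 1, name='policy_id'))
policies.to_excel(EXAMPLE_FILES['policy_data'])

# Net cashflows per policy over 30 periods, sharing the same policy_id index.
cashflows = pd.DataFrame(rng.normal(size=(n, 30)), columns=range(30), index=policies.index)
cashflows.to_excel(EXAMPLE_FILES['cashflow_base'])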
 
  class Clusters:
-     def __init__(self, loc_vars):
-         if loc_vars.empty:
-             raise ValueError("Input data for KMeans (loc_vars) is empty.")
-         if loc_vars.isnull().all().all():
-             raise ValueError("Input data for KMeans (loc_vars) contains all NaN values.")
-
-         n_samples = len(loc_vars)
-         n_clusters_to_use = min(1000, n_samples)
-         if n_clusters_to_use == 0:
-             raise ValueError("Cannot determine n_clusters as no samples are available.")
-
-         self.kmeans = KMeans(n_clusters=n_clusters_to_use, random_state=0, n_init=10).fit(np.ascontiguousarray(loc_vars))
-         closest, _ = pairwise_distances_argmin_min(self.kmeans.cluster_centers_, np.ascontiguousarray(loc_vars))
-
-         rep_ids = pd.Series(data=(closest + 1))
-         rep_ids.name = 'policy_id'
-         rep_ids.index.name = 'cluster_id'
-         self.rep_ids = rep_ids
-
-         if n_samples > 0:
-             self.policy_count = self.agg_by_cluster(pd.DataFrame({'policy_count': [1] * n_samples})).get('policy_count', pd.Series(dtype=int))
-             if self.policy_count is None:  # get can return None if key not present
-                 self.policy_count = pd.Series(dtype=int).rename_axis('cluster_id')
-         else:
-             self.policy_count = pd.Series(dtype=int).rename_axis('cluster_id')
-
-     def agg_by_cluster(self, df, agg=None):
-         temp = df.copy()
-         if len(self.kmeans.labels_) != len(df):
-             gr.Warning(f"Length mismatch in agg_by_cluster: kmeans.labels_ ({len(self.kmeans.labels_)}) vs df ({len(df)}).")
-             # Attempt to proceed if df is shorter, otherwise this indicates a deeper issue
-             if len(self.kmeans.labels_) < len(df):
-                 # Cannot assign labels if df is longer than available labels
-                 return pd.DataFrame()  # Or raise error
-             temp['cluster_id'] = self.kmeans.labels_[:len(df)]
-         else:
-             temp['cluster_id'] = self.kmeans.labels_
-
-         temp = temp.set_index('cluster_id')
-
-         agg_ops = {}
-         if isinstance(agg, dict):
-             agg_ops = {c: (agg[c] if c in agg else 'sum') for c in temp.columns if pd.api.types.is_numeric_dtype(temp[c])}
-         else:
-             for col in temp.columns:
-                 if pd.api.types.is_numeric_dtype(temp[col]):
-                     agg_ops[col] = 'sum'
-
-         if not agg_ops:  # No numeric columns or no valid agg ops
-             return pd.DataFrame(index=temp.index.unique())  # Return empty DF with cluster index
-
-         return temp.groupby(temp.index).agg(agg_ops)
-
-     def extract_reps(self, df):
-         df_reset = df.reset_index()
-         original_index_name = df.index.name if df.index.name else 'index'
-
-         # Ensure 'policy_id' column exists for the merge operation
-         if 'policy_id' not in df_reset.columns:
-             if original_index_name in df_reset.columns and original_index_name != 'policy_id':
-                 df_reset = df_reset.rename(columns={original_index_name: 'policy_id'})
-             elif original_index_name == 'policy_id':  # Already named policy_id
-                 pass
-             else:  # No identifiable policy_id column from index
-                 gr.Warning(f"Could not find 'policy_id' from index '{original_index_name}' for merging in extract_reps. Trying to merge on index if rep_ids index matches.")
-                 # This path is risky; the merge may fail if rep_ids' index (cluster_id) doesn't match df_reset's current index.
-                 # For safety, assume 'policy_id' must be present in rep_ids for the merge;
-                 # since rep_ids carries 'policy_id' as data, df_reset must have it as a column.
-
-         if self.rep_ids.empty:
-             gr.Warning("Representative IDs (rep_ids) are empty in extract_reps.")
-             # Return an empty DataFrame with df's columns, indexed by 'cluster_id';
-             # the cluster ids themselves cannot be known without rep_ids.
-             return pd.DataFrame(columns=df.columns).rename_axis('cluster_id')
-
-         temp = pd.merge(self.rep_ids, df_reset, how='left', on='policy_id')
-         temp = temp.set_index('cluster_id')  # rep_ids index is cluster_id
-
-         if 'policy_id' in temp.columns:  # Drop the policy_id column used for merging
-             return temp.drop('policy_id', axis=1)
-         return temp
-
-     def extract_and_scale_reps(self, df, agg=None):
-         extracted_df = self.extract_reps(df)
-         if extracted_df.empty:
-             return extracted_df
-
-         scaled_df = extracted_df.copy()
-         if self.policy_count.empty:
-             gr.Warning("Policy count is empty in extract_and_scale_reps. Not scaling.")
-             return scaled_df  # Return unscaled if no policy counts
-
-         policy_count_aligned = self.policy_count.reindex(scaled_df.index).fillna(0)
-
-         if agg and isinstance(agg, dict):
-             for c in extracted_df.columns:
-                 if pd.api.types.is_numeric_dtype(extracted_df[c]):
-                     if agg.get(c, 'sum') == 'sum':  # Default to 'sum' for scaling
-                         scaled_df[c] = extracted_df[c].mul(policy_count_aligned, axis=0)
-         else:
-             for c in extracted_df.columns:
-                 if pd.api.types.is_numeric_dtype(extracted_df[c]):
-                     scaled_df[c] = extracted_df[c].mul(policy_count_aligned, axis=0)
-         return scaled_df
-
-     def compare(self, df, agg=None):
-         source = self.agg_by_cluster(df, agg)  # Aggregated actuals per cluster
-
-         # Target: representative values, potentially scaled by policy_count for 'sum' type aggregations
-         target_reps_raw = self.extract_reps(df)  # Raw representative values per cluster
-
-         if source.empty and target_reps_raw.empty:
-             return pd.DataFrame(columns=['actual', 'estimate'])
-         if source.empty:  # Fill with NaNs if only source is empty
-             source = pd.DataFrame(index=target_reps_raw.index, columns=target_reps_raw.columns)
-         if target_reps_raw.empty:  # Fill with NaNs if only target is empty
-             target_reps_raw = pd.DataFrame(index=source.index, columns=source.columns)
-
-         target_estimates_per_cluster = target_reps_raw.copy()
-         if not self.policy_count.empty:
-             policy_count_aligned = self.policy_count.reindex(target_reps_raw.index).fillna(0)
-
-             if isinstance(agg, dict):
-                 for col, method in agg.items():
-                     if col in target_estimates_per_cluster.columns and method == 'sum':
-                         if pd.api.types.is_numeric_dtype(target_estimates_per_cluster[col]):
-                             target_estimates_per_cluster[col] = target_reps_raw[col].mul(policy_count_aligned, axis=0)
-             elif not agg:  # Default to sum if agg is None (original notebook behavior)
-                 for col in target_estimates_per_cluster.columns:
-                     if pd.api.types.is_numeric_dtype(target_estimates_per_cluster[col]):
-                         target_estimates_per_cluster[col] = target_reps_raw[col].mul(policy_count_aligned, axis=0)
-         else:  # No policy_count, target estimates remain raw rep values
-             gr.Warning("Policy_count is empty, compare() target estimates will be raw representative values.")
-
-         # Align source and target_estimates_per_cluster before stacking
-         aligned_source, aligned_target = source.align(target_estimates_per_cluster, join='outer', axis=0)  # outer join on clusters
-         aligned_source, aligned_target = aligned_source.align(aligned_target, join='outer', axis=1)  # outer join on columns
-
-         return pd.DataFrame({'actual': aligned_source.stack(dropna=False), 'estimate': aligned_target.stack(dropna=False)})
-
-     def compare_total(self, df, agg=None):
-         if df.empty:
-             return pd.DataFrame(columns=['actual', 'estimate', 'error'])
-
-         op_for_actual = {}
-         numeric_cols_df = df.select_dtypes(include=np.number).columns
-         if isinstance(agg, dict):
-             for c in numeric_cols_df: op_for_actual[c] = agg.get(c, 'sum')
-         else:
-             for c in numeric_cols_df: op_for_actual[c] = 'sum'
-
-         if not op_for_actual:  # No numeric columns to aggregate
-             return pd.DataFrame(columns=['actual', 'estimate', 'error'])
-
-         actual = df.agg(op_for_actual).dropna()
-         if actual.empty:  # No results from aggregation
-             return pd.DataFrame(columns=['actual', 'estimate', 'error'])
-
-         reps_values = self.extract_reps(df)
-         estimate_values = {}
-
-         if reps_values.empty or self.policy_count.empty:
-             estimate = pd.Series(index=actual.index, dtype=float).fillna(np.nan)
-         else:
-             policy_count_aligned = self.policy_count.reindex(reps_values.index).fillna(0)
-             total_weight = policy_count_aligned.sum()
-
-             for col_name in actual.index:  # Iterate over columns that had a valid actual aggregation
-                 col_op = op_for_actual.get(col_name)
-
-                 if col_name not in reps_values.columns or not pd.api.types.is_numeric_dtype(reps_values[col_name]):
-                     estimate_values[col_name] = np.nan; continue
-
-                 rep_col_values = reps_values[col_name]
-                 if col_op == 'sum':
-                     estimate_values[col_name] = (rep_col_values * policy_count_aligned).sum()
-                 elif col_op == 'mean':
-                     if total_weight != 0:
-                         weighted_sum = (rep_col_values * policy_count_aligned).sum()
-                         estimate_values[col_name] = weighted_sum / total_weight
-                     else: estimate_values[col_name] = np.nan
-                 else: estimate_values[col_name] = np.nan  # Should not happen
-             estimate = pd.Series(estimate_values, index=actual.index)  # Align with actual's index
-
-         actual_aligned, estimate_aligned = actual.align(estimate, join='inner')  # Only compare where both exist
-         if actual_aligned.empty:  # Nothing to compare
-             return pd.DataFrame({'actual': actual, 'estimate': estimate, 'error': pd.Series(index=actual.index, dtype=float)})
-
-         error = pd.Series(index=actual_aligned.index, dtype=float)
-         valid_mask = (actual_aligned != 0) & (~actual_aligned.isna())
-         error[valid_mask] = estimate_aligned[valid_mask] / actual_aligned[valid_mask] - 1
-
-         actual_zero_mask = (actual_aligned == 0) & (~actual_aligned.isna())
-         error[actual_zero_mask & (estimate_aligned == 0)] = 0.0
-         error[actual_zero_mask & (estimate_aligned != 0) & (~estimate_aligned.isna())] = np.inf
-
-         error = error.replace([np.inf, -np.inf], np.nan)  # Convert inf to NaN for mean, etc.
-
-         return pd.DataFrame({'actual': actual_aligned, 'estimate': estimate_aligned, 'error': error})
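The error column above follows the relative-error convention error = estimate / actual - 1, with special cases when the actual total is zero. A standalone illustration with hypothetical totals (not taken from the app's data):

import numpy as np
import pandas as pd

# 0/0 is treated as zero error; a nonzero estimate against a zero actual is
# marked inf and then converted to NaN so it drops out of downstream means.
actual = pd.Series({'PV_NetCF': 200.0, 'PV_Premium': 0.0, 'PV_Claims': 0.0})
estimate = pd.Series({'PV_NetCF': 210.0, 'PV_Premium': 0.0, 'PV_Claims': 5.0})

error = pd.Series(np.nan, index=actual.index)
nonzero = actual != 0
error[nonzero] = estimate[nonzero] / actual[nonzero] - 1
error[(actual == 0) & (estimate == 0)] = 0.0
print(error)  # PV_NetCF 0.05, PV_Premium 0.00, PV_Claims NaN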
 
 
  def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
-     if not cfs_list or cluster_obj is None or not titles or len(cfs_list) == 0:  # cluster_obj can be None if init failed
-         fig, ax = plt.subplots(); ax.text(0.5, 0.5, "No data/cluster for cashflow plot.", ha='center', va='center')
-         buf = io.BytesIO(); plt.savefig(buf, format='png'); buf.seek(0); img = Image.open(buf); plt.close(fig); return img
-
-     num_plots = len(cfs_list)
-     cols = min(2, num_plots) if num_plots > 0 else 1
-     rows = (num_plots + cols - 1) // cols if num_plots > 0 else 1
-
-     fig, axes = plt.subplots(rows, cols, figsize=(7.5 * cols, 5 * rows), squeeze=False)
-     axes = axes.flatten()
-     plot_made = False
-
-     for i, (df_cf, title) in enumerate(zip(cfs_list, titles)):
-         if i < len(axes):
-             ax_curr = axes[i]; ax_curr.set_title(title)
-             if df_cf is None or df_cf.empty:
-                 ax_curr.text(0.5, 0.5, f"No data for\n{title}", ha='center', va='center', wrap=True); continue
-             try:
-                 comparison = cluster_obj.compare_total(df_cf)
-                 if not comparison.empty and 'actual' in comparison.columns and 'estimate' in comparison.columns:
-                     plot_df = comparison[['actual', 'estimate']].dropna(how='all')
-                     if not plot_df.empty:
-                         plot_df.plot(ax=ax_curr, grid=True)
-                         ax_curr.set_xlabel('Time Period'); ax_curr.set_ylabel('Cashflow Value')
-                         plot_made = True
-                     else: ax_curr.text(0.5, 0.5, f"No comparable data\nfor {title}", ha='center', va='center', wrap=True)
-                 else: ax_curr.text(0.5, 0.5, f"Comparison failed\nfor {title}", ha='center', va='center', wrap=True)
-             except Exception as e: ax_curr.text(0.5, 0.5, f"Error plotting {title}:\n{str(e)[:50]}...", ha='center', va='center', wrap=True)
-
-     for j in range(num_plots, len(axes)): fig.delaxes(axes[j])  # Remove unused axes
-     if not plot_made:
-         plt.close(fig); fig, ax = plt.subplots(); ax.text(0.5, 0.5, "No cashflow plots generated.", ha='center', va='center')
-
-     plt.tight_layout(pad=2.0)
-     buf = io.BytesIO(); plt.savefig(buf, format='png', dpi=90); buf.seek(0); img = Image.open(buf); plt.close(fig); return img
 
 
  def plot_scatter_comparison(df_compare_output, title):
-     if df_compare_output is None or df_compare_output.empty:
-         fig, ax = plt.subplots(figsize=(8, 5)); ax.text(0.5, 0.5, "No data for scatter plot.", ha='center', va='center'); ax.set_title(title)
-         buf = io.BytesIO(); plt.savefig(buf, format='png'); buf.seek(0); img = Image.open(buf); plt.close(fig); return img
-
-     fig, ax = plt.subplots(figsize=(8, 5))
-     ax.set_title(title, fontsize='medium')
-
-     if not isinstance(df_compare_output.index, pd.MultiIndex) or df_compare_output.index.nlevels < 2:
-         ax.scatter(df_compare_output.get('actual', pd.Series(dtype=float)), df_compare_output.get('estimate', pd.Series(dtype=float)), s=9, alpha=0.6)
-     else:
-         try:
-             unique_levels = df_compare_output.index.get_level_values(1).unique()
-             if len(unique_levels) == 0: ax.text(0.5, 0.5, "No data points for scatter.", ha='center', va='center')
-             else:
-                 colors = matplotlib.cm.rainbow(np.linspace(0, 1, len(unique_levels)))
-                 for item_level, color_val in zip(unique_levels, colors):
-                     try: subset = df_compare_output.xs(item_level, level=1)
-                     except KeyError: continue  # Level not found, skip
-                     if not subset.empty: ax.scatter(subset['actual'], subset['estimate'], color=color_val, s=9, alpha=0.6, label=str(item_level))
-                 if len(unique_levels) > 1 and len(unique_levels) <= 10: ax.legend(title=str(df_compare_output.index.names[1]), fontsize='small')
-         except IndexError:  # Problem with index levels
-             ax.scatter(df_compare_output.get('actual', pd.Series(dtype=float)), df_compare_output.get('estimate', pd.Series(dtype=float)), s=9, alpha=0.6)
-             gr.Warning("Could not process levels for scatter plot, showing raw data.")
-
-     ax.set_xlabel('Actual Value'); ax.set_ylabel('Estimated Value')
-     ax.grid(True, linestyle='--', alpha=0.7)
-
-     try:  # Draw identity line
-         current_xlim = ax.get_xlim(); current_ylim = ax.get_ylim()
-         if np.isfinite(current_xlim).all() and np.isfinite(current_ylim).all() and current_xlim[0] < current_xlim[1] and current_ylim[0] < current_ylim[1]:
-             lims = [np.nanmin([current_xlim[0], current_ylim[0]]), np.nanmax([current_xlim[1], current_ylim[1]])]
-             if lims[0] < lims[1] and not np.isnan(lims[0]) and not np.isnan(lims[1]):
-                 ax.plot(lims, lims, 'r-', linewidth=1, alpha=0.8, dashes=(3, 3)); ax.set_xlim(lims); ax.set_ylim(lims)
-     except Exception: pass
-
-     plt.tight_layout(pad=1.5)
-     buf = io.BytesIO(); plt.savefig(buf, format='png', dpi=90); buf.seek(0); img = Image.open(buf); plt.close(fig); return img
 
 
  def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
-                   policy_data_path, pv_base_path, pv_lapse_path, pv_mort_path):
-     results = {}
-     try:
-         cfs = pd.read_excel(cashflow_base_path, index_col=0)
-         cfs_lapse50 = pd.read_excel(cashflow_lapse_path, index_col=0)
-         cfs_mort15 = pd.read_excel(cashflow_mort_path, index_col=0)
-
-         pol_data_full = pd.read_excel(policy_data_path, index_col=0)
-         required_policy_cols = ['age_at_entry', 'policy_term', 'sum_assured', 'duration_mth']
-         missing_policy_cols = [col for col in required_policy_cols if col not in pol_data_full.columns]
-         if missing_policy_cols: gr.Warning(f"Policy data missing: {', '.join(missing_policy_cols)}.")
-         pol_data = pol_data_full[required_policy_cols] if not missing_policy_cols else pol_data_full
-
-         pvs = pd.read_excel(pv_base_path, index_col=0)
-         pvs_lapse50 = pd.read_excel(pv_lapse_path, index_col=0)
-         pvs_mort15 = pd.read_excel(pv_mort_path, index_col=0)
-
-         cfs_list = [cfs, cfs_lapse50, cfs_mort15]
-         scen_titles = ['Base', 'Lapse+50%', 'Mort+15%']
-         mean_attrs_agg = {'age_at_entry': 'mean', 'policy_term': 'mean', 'duration_mth': 'mean', 'sum_assured': 'sum'}
-
-         gr.Info("Processing calibrations...")
-         cluster_cfs = cluster_attrs = cluster_pvs = None  # Initialize
-         if not cfs.empty: cluster_cfs = Clusters(cfs)
-         else: gr.Warning("Base cashflow data is empty. CF Calib. might be affected or skipped.")
-
-         if cluster_cfs:
-             results['cf_total_base_table'] = cluster_cfs.compare_total(cfs)
-             results['cf_policy_attrs_total'] = cluster_cfs.compare_total(pol_data, agg=mean_attrs_agg)
-             results['cf_pv_total_base'] = cluster_cfs.compare_total(pvs)
-             results['cf_pv_total_lapse'] = cluster_cfs.compare_total(pvs_lapse50)
-             results['cf_pv_total_mort'] = cluster_cfs.compare_total(pvs_mort15)
-             results['cf_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_cfs, scen_titles)
-             results['cf_scatter_cashflows_base'] = plot_scatter_comparison(cluster_cfs.compare(cfs), 'CF Calib. - Cashflows (Base)')
-
-         if not pol_data.empty:
-             pol_data_min = pol_data.min(); pol_data_range = pol_data.max() - pol_data_min
-             pol_data_range_safe = pol_data_range.copy()
-             pol_data_range_safe[pol_data_range_safe == 0] = 1  # Avoid division by zero for constant columns
-             loc_vars_attrs = ((pol_data - pol_data_min) / pol_data_range_safe).fillna(0)  # Standardize
-             if not loc_vars_attrs.empty: cluster_attrs = Clusters(loc_vars_attrs)
-         else: gr.Warning("Policy data is empty. Attr Calib. skipped.")
-
-         if cluster_attrs:
-             results['attr_total_cf_base'] = cluster_attrs.compare_total(cfs)
-             results['attr_policy_attrs_total'] = cluster_attrs.compare_total(pol_data, agg=mean_attrs_agg)
-             results['attr_total_pv_base'] = cluster_attrs.compare_total(pvs)
-             results['attr_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_attrs, scen_titles)
-             results['attr_scatter_cashflows_base'] = plot_scatter_comparison(cluster_attrs.compare(cfs), 'Attr Calib. - Cashflows (Base)')
-
-         if not pvs.empty: cluster_pvs = Clusters(pvs)
-         else: gr.Warning("Base PV data is empty. PV Calib. might be affected or skipped.")
-
-         if cluster_pvs:
-             results['pv_total_cf_base'] = cluster_pvs.compare_total(cfs)
-             results['pv_policy_attrs_total'] = cluster_pvs.compare_total(pol_data, agg=mean_attrs_agg)
-             results['pv_total_pv_base'] = cluster_pvs.compare_total(pvs)
-             results['pv_total_pv_lapse'] = cluster_pvs.compare_total(pvs_lapse50)
-             results['pv_total_pv_mort'] = cluster_pvs.compare_total(pvs_mort15)
-             results['pv_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_pvs, scen_titles)
-             results['pv_scatter_pvs_base'] = plot_scatter_comparison(cluster_pvs.compare(pvs), 'PV Calib. - PVs (Base)')
-
-         gr.Info("Generating Summary Plot...")
-         error_data = {}
-         pv_col_name = 'PV_NetCF'
-         calibration_objects_for_summary = [
-             ("CF Calib.", cluster_cfs), ("Attr Calib.", cluster_attrs), ("PV Calib.", cluster_pvs)
-         ]
-
-         for calib_name_display, cl_obj in calibration_objects_for_summary:
-             current_errors = []
-             if cl_obj is None: current_errors = [np.nan, np.nan, np.nan]
-             else:
-                 for pv_df_scen in [pvs, pvs_lapse50, pvs_mort15]:
-                     err_val = np.nan
-                     if not pv_df_scen.empty:
-                         comp_df = cl_obj.compare_total(pv_df_scen)
-                         if not comp_df.empty:
-                             if pv_col_name in comp_df.index: err_val = comp_df.loc[pv_col_name, 'error']
-                             elif 'error' in comp_df.columns: err_val = comp_df['error'].mean()  # Fallback
-                     current_errors.append(abs(err_val))
-             error_data[calib_name_display] = current_errors
-
-         summary_df = pd.DataFrame(error_data, index=['Base', 'Lapse+50%', 'Mort+15%']).round(4)  # Round summary errors
-         fig_summary, ax_summary = plt.subplots(figsize=(10, 6))
-         plot_title = f'Abs. Error in Total {pv_col_name} by Calibration Method'
-
-         if summary_df.isnull().all().all() or summary_df.empty:
-             ax_summary.text(0.5, 0.5, f"Summary N/A.\nCheck PV files for '{pv_col_name}'.", ha='center', va='center', wrap=True)
-         else:
-             summary_df.plot(kind='bar', ax=ax_summary, grid=True, width=0.8, legend=True)
-             ax_summary.set_ylabel(f'Absolute Error (of {pv_col_name} or fallback mean)'); ax_summary.tick_params(axis='x', rotation=0)
-             ax_summary.legend(title="Calibration Method")
-         ax_summary.set_title(plot_title)
-         plt.tight_layout(pad=1.5)
-         buf_summary = io.BytesIO(); plt.savefig(buf_summary, format='png', dpi=90); buf_summary.seek(0)
-         results['summary_plot'] = Image.open(buf_summary); plt.close(fig_summary)
-
-         for key, value in results.items():
-             if isinstance(value, pd.DataFrame):
-                 try: results[key] = value.round(2)
-                 except (TypeError, AttributeError): pass  # Ignore non-numeric data for rounding
-         gr.Info("All processing complete. ✅")
-         return results
-
-     except FileNotFoundError as e: gr.Error(f"File not found: {e.filename}."); return {"error": str(e)}
-     except ValueError as e: gr.Error(f"Data error: {str(e)}"); return {"error": str(e)}
-     except KeyError as e: gr.Error(f"Missing column: {e}. Check data formats."); return {"error": str(e)}
-     except Exception as e:
-         gr.Error(f"Unexpected error: {str(e)}"); import traceback; traceback.print_exc()
-         return {"error": str(e)}
 
 
  def create_interface():
-     with gr.Blocks(title="Cluster Model Points Analysis", theme=gr.themes.Default()) as demo:
-         gr.Markdown("## Cluster Model Points Analysis 📈")
-         gr.Markdown("Applies k-means clustering for model point selection in insurance portfolios. Upload Excel files or use examples.")
-         with gr.Accordion("📚 File Requirements & Instructions", open=False):
-             gr.Markdown(
-                 """
-                 **Required Excel (.xlsx) Files (Index: `policy_id` for all):**
-                 1. **Cashflows (Base, Lapse Stress, Mort Stress)**: Net annual cashflows (cols: time periods).
-                 2. **Policy Data**: Attributes. Must include: `age_at_entry`, `policy_term`, `sum_assured`, `duration_mth`.
-                 3. **Present Values (Base, Lapse Stress, Mort Stress)**: PVs of cashflow components. Ideally include `PV_NetCF`.
-                 All files must share a common `policy_id` (use `index_col=0` if it's the first column).
-                 """
-             )
-
-         with gr.Row():
-             with gr.Column(scale=3):
-                 gr.Markdown("#### 📂 Upload Files or Load Examples")
-                 with gr.Row():
-                     cashflow_base_input = gr.File(label="CF Base", file_types=[".xlsx"], scale=1)
-                     cashflow_lapse_input = gr.File(label="CF Lapse Str.", file_types=[".xlsx"], scale=1)
-                     cashflow_mort_input = gr.File(label="CF Mort Str.", file_types=[".xlsx"], scale=1)
-                 with gr.Row():
-                     policy_data_input = gr.File(label="Policy Data", file_types=[".xlsx"], scale=1)
-                     pv_base_input = gr.File(label="PV Base", file_types=[".xlsx"], scale=1)
-                     pv_lapse_input = gr.File(label="PV Lapse Str.", file_types=[".xlsx"], scale=1)
-                 with gr.Row():
-                     pv_mort_input = gr.File(label="PV Mort Str.", file_types=[".xlsx"], scale=1)
-                     # Dummy invisible components for layout, if needed, or adjust column scales
-                     gr.HTML("", scale=1, visible=False)
-                     gr.HTML("", scale=1, visible=False)
-
-             with gr.Column(scale=1, min_width=180):  # Adjusted min_width
-                 gr.Markdown("ㅤ")  # Spacer for alignment
-                 load_example_btn = gr.Button("Load Example Data", icon="💾", full_width=True)
-                 analyze_btn = gr.Button("Analyze Dataset", variant="primary", icon="🚀", full_width=True)
-
-         with gr.Tabs():
-             with gr.TabItem("📊 Summary", id="summary_tab"):
-                 summary_plot_output = gr.Image(label="Calibration Methods Comparison", type="pil")  # Use type="pil"
-
-             tab_items_data = [
-                 ("💸 CF Calib.", "cf", "Annual Cashflows (Base)"),
-                 ("👤 Attr Calib.", "attr", "Policy Attributes"),
-                 ("💰 PV Calib.", "pv", "Present Values (Base)")
-             ]
-
-             # Dynamically create output components and store them
-             output_component_map = {"summary_plot_output": summary_plot_output}
-
-             for tab_name, prefix, calib_vars_desc in tab_items_data:
-                 with gr.TabItem(tab_name, id=f"{prefix}_calib_tab"):
-                     gr.Markdown(f"#### Results: Using {calib_vars_desc} as Calibration Variables")
-                     with gr.Row():
-                         # Removed height parameter
-                         output_component_map[f"{prefix}_total_base_table_out"] = gr.Dataframe(label="Overall Comparison - Base CF", wrap=True)
-                         output_component_map[f"{prefix}_policy_attrs_total_out"] = gr.Dataframe(label="Overall Comparison - Policy Attr.", wrap=True)
-
-                     output_component_map[f"{prefix}_cashflow_plot_out"] = gr.Image(label="Cashflow Value Comparisons", type="pil")
-
-                     scatter_label = "Scatter: Per-Cluster PVs (Base)" if prefix == "pv" else "Scatter: Per-Cluster CFs (Base)"
-                     output_component_map[f"{prefix}_scatter_display_out"] = gr.Image(label=scatter_label, type="pil")
-
-                     with gr.Accordion("Present Value Comparisons (Totals)", open=False):
-                         with gr.Row():
-                             # Removed height parameter
-                             output_component_map[f"{prefix}_pv_total_base_out"] = gr.Dataframe(label="PVs - Base", wrap=True)
-                             if prefix != "attr":
-                                 output_component_map[f"{prefix}_pv_total_lapse_out"] = gr.Dataframe(label="PVs - Lapse Stress", wrap=True)
-                                 output_component_map[f"{prefix}_pv_total_mort_out"] = gr.Dataframe(label="PVs - Mortality Stress", wrap=True)
-
-         # Define the list of all output components in the correct order for the click handler
-         ordered_output_keys = [
-             'summary_plot_output',
-             'cf_total_base_table_out', 'cf_policy_attrs_total_out', 'cf_cashflow_plot_out', 'cf_scatter_display_out',
-             'cf_pv_total_base_out', 'cf_pv_total_lapse_out', 'cf_pv_total_mort_out',
-             'attr_total_base_table_out', 'attr_policy_attrs_total_out', 'attr_cashflow_plot_out', 'attr_scatter_display_out',
-             'attr_pv_total_base_out',
-             'pv_total_base_table_out', 'pv_policy_attrs_total_out', 'pv_cashflow_plot_out', 'pv_scatter_display_out',
-             'pv_total_pv_base_out', 'pv_pv_total_lapse_out', 'pv_pv_total_mort_out'
-         ]
-         # Filter out keys that might not be created if a tab's structure changes (e.g., attr_pv_total_lapse)
-         final_output_components = [output_component_map[k] for k in ordered_output_keys if k in output_component_map]
-
-         input_file_components = [
-             cashflow_base_input, cashflow_lapse_input, cashflow_mort_input,
-             policy_data_input, pv_base_input, pv_lapse_input, pv_mort_input
-         ]
-
-         def handle_analysis_click(*files_input):
-             if not all(f is not None for f in files_input):
-                 gr.Warning("Not all files provided. Please upload/load all 7 files.")
-                 return [None] * len(final_output_components)
-
-             file_paths = []
-             for f_obj in files_input:
-                 if hasattr(f_obj, 'name') and isinstance(f_obj.name, str): file_paths.append(f_obj.name)
-                 elif isinstance(f_obj, str): file_paths.append(f_obj)
-                 else: gr.Error(f"Invalid file input: {f_obj}."); return [None] * len(final_output_components)
-
-             analysis_results = process_files(*file_paths)
-             if "error" in analysis_results and analysis_results["error"]:  # check if error is not None or empty
-                 return [None] * len(final_output_components)
-
-             # Map results to output components based on the ordered_output_keys
-             output_values = []
-             # Keys used in process_files for results dict:
-             # summary_plot
-             # cf_total_base_table, cf_policy_attrs_total, cf_pv_total_base, cf_pv_total_lapse, cf_pv_total_mort, cf_cashflow_plot, cf_scatter_cashflows_base
-             # attr_total_cf_base, attr_policy_attrs_total, attr_total_pv_base, attr_cashflow_plot, attr_scatter_cashflows_base
-             # pv_total_cf_base, pv_policy_attrs_total, pv_total_pv_base, pv_total_pv_lapse, pv_total_pv_mort, pv_cashflow_plot, pv_scatter_pvs_base
-
-             key_map = {  # Maps UI component key stem to result key stem
-                 'total_base_table_out': 'total_base_table',
-                 'policy_attrs_total_out': 'policy_attrs_total',
-                 'cashflow_plot_out': 'cashflow_plot',
-                 'scatter_display_out': lambda p: f'scatter_{"pvs" if p == "pv" else "cashflows"}_base',  # Special handling for scatter key
-                 'pv_total_base_out': 'pv_total_base',
-                 'pv_total_lapse_out': 'pv_total_lapse',
-                 'pv_total_mort_out': 'pv_total_mort'
-             }
-
-             for ui_key in ordered_output_keys:
-                 if ui_key == "summary_plot_output":
-                     output_values.append(analysis_results.get('summary_plot'))
-                     continue
-
-                 # Deconstruct ui_key: e.g., "cf_total_base_table_out" -> prefix="cf", stem="total_base_table_out"
-                 parts = ui_key.split('_', 1)
-                 prefix = parts[0]
-                 stem_ui = parts[1]
-
-                 result_stem_mapper = key_map.get(stem_ui)
-                 if callable(result_stem_mapper):  # For scatter plot key
-                     result_key_stem = result_stem_mapper(prefix)
-                 else:
-                     result_key_stem = result_stem_mapper
-
-                 if result_key_stem:
-                     result_data_key = f"{prefix}_{result_key_stem}"
-                     output_values.append(analysis_results.get(result_data_key))
-                 else:  # Should not happen if ordered_output_keys and key_map are correct
-                     output_values.append(None)
-                     gr.Debug(f"No mapping found for UI key {ui_key}")
-
-             return output_values
-
-         analyze_btn.click(handle_analysis_click, inputs=input_file_components, outputs=final_output_components)
-
-         def load_example_files_action():
-             # Check that all example files exist, in the order the inputs expect
-             expected_order = ["cashflow_base", "cashflow_lapse", "cashflow_mort", "policy_data", "pv_base", "pv_lapse", "pv_mort"]
-             example_file_paths = []
-             missing_files_list = []
-
-             for key in expected_order:
-                 f_path = EXAMPLE_FILES.get(key)
-                 if f_path and os.path.exists(f_path):
-                     example_file_paths.append(f_path)
-                 else:
-                     missing_files_list.append(f_path or f"'{key}' not configured")
-
-             if missing_files_list:
-                 gr.Error(f"Missing example data files: {', '.join(missing_files_list)}. Please ensure they exist in '{EXAMPLE_DATA_DIR}'.")
-                 return [None] * len(input_file_components)
-
-             gr.Info(f"Example data paths loaded from '{EXAMPLE_DATA_DIR}'. Click 'Analyze Dataset'.")
-             return example_file_paths
-
-         load_example_btn.click(load_example_files_action, inputs=None, outputs=input_file_components)  # No inputs for this button
-
-     return demo
 
  if __name__ == "__main__":
-     if not os.path.exists(EXAMPLE_DATA_DIR):
-         try: os.makedirs(EXAMPLE_DATA_DIR); print(f"Created '{EXAMPLE_DATA_DIR}'. Place example Excel files there.")
-         except OSError as e: print(f"Error creating {EXAMPLE_DATA_DIR}: {e}. Please create manually.")
-
-     print(f"Starting Gradio application... Ensure example files are in '{os.path.abspath(EXAMPLE_DATA_DIR)}'")
-     demo_app = create_interface()
-     demo_app.launch()
 
  import numpy as np
  import pandas as pd
  from sklearn.cluster import KMeans
+ from sklearn.metrics import pairwise_distances_argmin_min, r2_score
  import matplotlib.pyplot as plt
  import matplotlib.cm
  import io
+ import os  # Added for path joining
  from PIL import Image
 
  # Define the paths for example data
  EXAMPLE_DATA_DIR = "eg_data"
  EXAMPLE_FILES = {
+     "cashflow_base": os.path.join(EXAMPLE_DATA_DIR, "cashflows_seriatim_10K.xlsx"),
+     "cashflow_lapse": os.path.join(EXAMPLE_DATA_DIR, "cashflows_seriatim_10K_lapse50.xlsx"),
+     "cashflow_mort": os.path.join(EXAMPLE_DATA_DIR, "cashflows_seriatim_10K_mort15.xlsx"),
+     "policy_data": os.path.join(EXAMPLE_DATA_DIR, "model_point_table.xlsx"),  # Assuming this is the correct path/name for the example
+     "pv_base": os.path.join(EXAMPLE_DATA_DIR, "pv_seriatim_10K.xlsx"),
+     "pv_lapse": os.path.join(EXAMPLE_DATA_DIR, "pv_seriatim_10K_lapse50.xlsx"),
+     "pv_mort": os.path.join(EXAMPLE_DATA_DIR, "pv_seriatim_10K_mort15.xlsx"),
  }
 
  class Clusters:
+     def __init__(self, loc_vars):
+         self.kmeans = kmeans = KMeans(n_clusters=1000, random_state=0, n_init=10).fit(np.ascontiguousarray(loc_vars))
+         closest, _ = pairwise_distances_argmin_min(kmeans.cluster_centers_, np.ascontiguousarray(loc_vars))
+ 
+         rep_ids = pd.Series(data=(closest + 1))  # 0-based to 1-based indexes
+         rep_ids.name = 'policy_id'
+         rep_ids.index.name = 'cluster_id'
+         self.rep_ids = rep_ids
+ 
+         self.policy_count = self.agg_by_cluster(pd.DataFrame({'policy_count': [1] * len(loc_vars)}))['policy_count']
+ 
+     def agg_by_cluster(self, df, agg=None):
+         """Aggregate columns by cluster"""
+         temp = df.copy()
+         temp['cluster_id'] = self.kmeans.labels_
+         temp = temp.set_index('cluster_id')
+         agg = {c: (agg[c] if agg and c in agg else 'sum') for c in temp.columns} if agg else "sum"
+         return temp.groupby(temp.index).agg(agg)
+ 
+     def extract_reps(self, df):
+         """Extract the rows of representative policies"""
+         temp = pd.merge(self.rep_ids, df.reset_index(), how='left', on='policy_id')
+         temp.index.name = 'cluster_id'
+         return temp.drop('policy_id', axis=1)
+ 
+     def extract_and_scale_reps(self, df, agg=None):
+         """Extract and scale the rows of representative policies"""
+         if agg:
+             cols = df.columns
+             mult = pd.DataFrame({c: (self.policy_count if (c not in agg or agg[c] == 'sum') else 1) for c in cols})
+             # Ensure mult has the same index as extract_reps(df) for proper alignment
+             extracted_df = self.extract_reps(df)
+             mult.index = extracted_df.index
+             return extracted_df.mul(mult)
+         else:
+             return self.extract_reps(df).mul(self.policy_count, axis=0)
+ 
+     def compare(self, df, agg=None):
+         """Returns a multi-indexed DataFrame comparing actual and estimate"""
+         source = self.agg_by_cluster(df, agg)
+         target = self.extract_and_scale_reps(df, agg)
+         return pd.DataFrame({'actual': source.stack(), 'estimate': target.stack()})
+ 
+     def compare_total(self, df, agg=None):
+         """Aggregate df by columns"""
+         if agg:
+             op = {c: (agg[c] if c in agg else 'sum') for c in df.columns}
+             actual = df.agg(op)
+ 
+             # The estimate must apply the aggregation ops *after* scaling. With agg=op,
+             # extract_and_scale_reps multiplies 'sum' columns by policy_count and leaves
+             # 'mean' columns unscaled, so 'mean' columns still need a weighted average here.
+             scaled_reps = self.extract_and_scale_reps(df, agg=op)
+ 
+             estimate_agg_ops = {}
+             for col_name, agg_type in op.items():
+                 if agg_type == 'mean':
+                     # Weighted average for mean columns
+                     estimate_agg_ops[col_name] = lambda s, c=col_name: (s * self.policy_count.reindex(s.index)).sum() / self.policy_count.reindex(s.index).sum() if c in self.policy_count.name else s.mean()
+                 else:  # 'sum'
+                     estimate_agg_ops[col_name] = 'sum'
+ 
+             estimate_scaled = self.extract_and_scale_reps(df, agg=op)  # agg=op is important here
+ 
+             final_estimate_ops = {}
+             for col, method in op.items():
+                 if method == 'mean':
+                     # Weighted mean: sum(rep value * policy_count) / sum(policy_count),
+                     # computed from the unscaled representative rows.
+                     reps_unscaled_for_mean = self.extract_reps(df)
+                     estimate_values = {}
+                     for c in df.columns:
+                         if op[c] == 'sum':
+                             estimate_values[c] = reps_unscaled_for_mean[c].mul(self.policy_count, axis=0).sum()
+                         elif op[c] == 'mean':
+                             weighted_sum = (reps_unscaled_for_mean[c] * self.policy_count).sum()
+                             total_weight = self.policy_count.sum()
+                             estimate_values[c] = weighted_sum / total_weight if total_weight else 0
+                     estimate = pd.Series(estimate_values)
+                 else:  # original 'sum' logic for all columns
+                     final_estimate_ops[col] = 'sum'  # all columns in estimate_scaled are ready to be summed
+                     estimate = estimate_scaled.agg(final_estimate_ops)
+ 
+         else:  # original logic if no agg is specified (all sum)
+             actual = df.sum()
+             estimate = self.extract_and_scale_reps(df).sum()
+ 
+         return pd.DataFrame({'actual': actual, 'estimate': estimate, 'error': estimate / actual - 1})
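The mean/sum handling that compare_total is driving at is easier to check in isolation. A minimal sketch, assuming a hypothetical reps frame (one representative row per cluster, as extract_reps would return) and a policy_count Series of cluster sizes; estimate_totals and all values below are illustrative only:

import pandas as pd

def estimate_totals(reps: pd.DataFrame, policy_count: pd.Series, op: dict) -> pd.Series:
    # 'sum' columns: total = sum over clusters of (rep value * cluster size).
    # 'mean' columns: cluster-size-weighted average of the rep values.
    total_weight = policy_count.sum()
    out = {}
    for col, how in op.items():
        weighted = (reps[col] * policy_count).sum()
        out[col] = weighted if how == 'sum' else weighted / total_weight
    return pd.Series(out)

# Two clusters of sizes 3 and 1 (synthetic numbers):
reps = pd.DataFrame({'sum_assured': [100.0, 200.0], 'age_at_entry': [30.0, 50.0]})
counts = pd.Series([3, 1])
print(estimate_totals(reps, counts, {'sum_assured': 'sum', 'age_at_entry': 'mean'}))
# sum_assured: 100*3 + 200*1 = 500.0; age_at_entry: (30*3 + 50*1) / 4 = 35.0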
 
 
  def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
+     """Create cashflow comparison plots"""
+     if not cfs_list or not cluster_obj or not titles:
+         return None  # Or a placeholder image
+     num_plots = len(cfs_list)
+     if num_plots == 0:
+         return None
+ 
+     # Determine subplot layout (e.g., 2x2 or adapt)
+     cols = 2
+     rows = (num_plots + cols - 1) // cols
+ 
+     fig, axes = plt.subplots(rows, cols, figsize=(15, 5 * rows), squeeze=False)  # Ensure axes is always 2D
+     axes = axes.flatten()
+ 
+     for i, (df, title) in enumerate(zip(cfs_list, titles)):
+         if i < len(axes):
+             comparison = cluster_obj.compare_total(df)
+             comparison[['actual', 'estimate']].plot(ax=axes[i], grid=True, title=title)
+             axes[i].set_xlabel('Time')  # Assuming x-axis is time for cashflows
+             axes[i].set_ylabel('Value')
+ 
+     # Hide any unused subplots
+     for j in range(i + 1, len(axes)):
+         fig.delaxes(axes[j])
+ 
+     plt.tight_layout()
+     buf = io.BytesIO()
+     plt.savefig(buf, format='png', dpi=100)  # Lowered DPI slightly for potentially faster rendering
+     buf.seek(0)
+     img = Image.open(buf)
+     plt.close(fig)  # Ensure figure is closed
+     return img
 
  def plot_scatter_comparison(df_compare_output, title):
+     """Create scatter plot comparison from compare() output"""
+     if df_compare_output is None or df_compare_output.empty:
+         # Create a blank plot with a message
+         fig, ax = plt.subplots(figsize=(12, 8))
+         ax.text(0.5, 0.5, "No data to display", ha='center', va='center', fontsize=15)
+         ax.set_title(title)
+         buf = io.BytesIO()
+         plt.savefig(buf, format='png', dpi=100)
+         buf.seek(0)
+         img = Image.open(buf)
+         plt.close(fig)
+         return img
+ 
+     fig, ax = plt.subplots(figsize=(12, 8))  # Use a single Axes object
+ 
+     if not isinstance(df_compare_output.index, pd.MultiIndex) or df_compare_output.index.nlevels < 2:
+         gr.Warning("Scatter plot data is not in the expected multi-index format. Plotting raw actual vs estimate.")
+         ax.scatter(df_compare_output['actual'], df_compare_output['estimate'], s=9, alpha=0.6)
+     else:
+         unique_levels = df_compare_output.index.get_level_values(1).unique()
+         colors = matplotlib.cm.rainbow(np.linspace(0, 1, len(unique_levels)))
+ 
+         for item_level, color_val in zip(unique_levels, colors):
+             subset = df_compare_output.xs(item_level, level=1)
+             ax.scatter(subset['actual'], subset['estimate'], color=color_val, s=9, alpha=0.6, label=item_level)
+         if len(unique_levels) > 1 and len(unique_levels) <= 10:  # Add legend if not too many items
+             ax.legend(title=df_compare_output.index.names[1])
+ 
+     ax.set_xlabel('Actual')
+     ax.set_ylabel('Estimate')
+     ax.set_title(title)
+     ax.grid(True)
+ 
+     # Draw identity line
+     lims = [
+         np.min([ax.get_xlim(), ax.get_ylim()]),
+         np.max([ax.get_xlim(), ax.get_ylim()]),
+     ]
+     if lims[0] != lims[1]:  # Avoid issues if all data is zero or a single point
+         ax.plot(lims, lims, 'r-', linewidth=0.5)
+         ax.set_xlim(lims)
+         ax.set_ylim(lims)
+ 
+     buf = io.BytesIO()
+     plt.savefig(buf, format='png', dpi=100)
+     buf.seek(0)
+     img = Image.open(buf)
+     plt.close(fig)
+     return img
+ 
  def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
+                   policy_data_path, pv_base_path, pv_lapse_path, pv_mort_path):
+     """Main processing function - now accepts file paths"""
+     try:
+         # Read uploaded files using paths
+         cfs = pd.read_excel(cashflow_base_path, index_col=0)
+         cfs_lapse50 = pd.read_excel(cashflow_lapse_path, index_col=0)
+         cfs_mort15 = pd.read_excel(cashflow_mort_path, index_col=0)
+ 
+         pol_data_full = pd.read_excel(policy_data_path, index_col=0)
+         # Ensure the correct columns are selected for pol_data
+         required_cols = ['age_at_entry', 'policy_term', 'sum_assured', 'duration_mth']
+         if all(col in pol_data_full.columns for col in required_cols):
+             pol_data = pol_data_full[required_cols]
+         else:
+             # Fallback or error if columns are missing. For now, try to use as is or a subset.
+             gr.Warning(f"Policy data might be missing required columns. Found: {pol_data_full.columns.tolist()}")
+             pol_data = pol_data_full
+ 
+         pvs = pd.read_excel(pv_base_path, index_col=0)
+         pvs_lapse50 = pd.read_excel(pv_lapse_path, index_col=0)
+         pvs_mort15 = pd.read_excel(pv_mort_path, index_col=0)
+ 
+         cfs_list = [cfs, cfs_lapse50, cfs_mort15]
+         # pvs_list = [pvs, pvs_lapse50, pvs_mort15]  # Not directly used for plotting in this structure
+         scen_titles = ['Base', 'Lapse+50%', 'Mort+15%']
+ 
+         results = {}
+ 
+         mean_attrs = {'age_at_entry': 'mean', 'policy_term': 'mean', 'duration_mth': 'mean', 'sum_assured': 'sum'}  # sum_assured is usually summed
+ 
+         # --- 1. Cashflow Calibration ---
+         cluster_cfs = Clusters(cfs)
+ 
+         results['cf_total_base_table'] = cluster_cfs.compare_total(cfs)
+         # results['cf_total_lapse_table'] = cluster_cfs.compare_total(cfs_lapse50)  # For full detail if needed
+         # results['cf_total_mort_table'] = cluster_cfs.compare_total(cfs_mort15)
+ 
+         results['cf_policy_attrs_total'] = cluster_cfs.compare_total(pol_data, agg=mean_attrs)
+ 
+         results['cf_pv_total_base'] = cluster_cfs.compare_total(pvs)
+         results['cf_pv_total_lapse'] = cluster_cfs.compare_total(pvs_lapse50)
+         results['cf_pv_total_mort'] = cluster_cfs.compare_total(pvs_mort15)
+ 
+         results['cf_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_cfs, scen_titles)
+         results['cf_scatter_cashflows_base'] = plot_scatter_comparison(cluster_cfs.compare(cfs), 'Cashflow Calib. - Cashflows (Base)')
+         # results['cf_scatter_policy_attrs'] = plot_scatter_comparison(cluster_cfs.compare(pol_data, agg=mean_attrs), 'Cashflow Calib. - Policy Attributes')
+         # results['cf_scatter_pvs_base'] = plot_scatter_comparison(cluster_cfs.compare(pvs), 'Cashflow Calib. - PVs (Base)')
+ 
+         # --- 2. Policy Attribute Calibration ---
+         # Standardize policy attributes
+         if not pol_data.empty and (pol_data.max() - pol_data.min()).all() != 0:  # Avoid division by zero if a column is constant
+             loc_vars_attrs = (pol_data - pol_data.min()) / (pol_data.max() - pol_data.min())
+         else:
+             gr.Warning("Policy data for attribute calibration is empty or has no variance. Skipping attribute calibration plots.")
+             loc_vars_attrs = pol_data  # or handle as an error/skip
+ 
+         if not loc_vars_attrs.empty:
+             cluster_attrs = Clusters(loc_vars_attrs)
+             results['attr_total_cf_base'] = cluster_attrs.compare_total(cfs)
+             results['attr_policy_attrs_total'] = cluster_attrs.compare_total(pol_data, agg=mean_attrs)
+             results['attr_total_pv_base'] = cluster_attrs.compare_total(pvs)
+             results['attr_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_attrs, scen_titles)
+             results['attr_scatter_cashflows_base'] = plot_scatter_comparison(cluster_attrs.compare(cfs), 'Policy Attr. Calib. - Cashflows (Base)')
+             # results['attr_scatter_policy_attrs'] = plot_scatter_comparison(cluster_attrs.compare(pol_data, agg=mean_attrs), 'Policy Attr. Calib. - Policy Attributes')
+         else:  # Fill with None if skipped
+             results['attr_total_cf_base'] = pd.DataFrame()
+             results['attr_policy_attrs_total'] = pd.DataFrame()
+             results['attr_total_pv_base'] = pd.DataFrame()
+             results['attr_cashflow_plot'] = None
+             results['attr_scatter_cashflows_base'] = None
+ 
+         # --- 3. Present Value Calibration ---
+         cluster_pvs = Clusters(pvs)
+ 
+         results['pv_total_cf_base'] = cluster_pvs.compare_total(cfs)
+         results['pv_policy_attrs_total'] = cluster_pvs.compare_total(pol_data, agg=mean_attrs)
+ 
+         results['pv_total_pv_base'] = cluster_pvs.compare_total(pvs)
+         results['pv_total_pv_lapse'] = cluster_pvs.compare_total(pvs_lapse50)
+         results['pv_total_pv_mort'] = cluster_pvs.compare_total(pvs_mort15)
+ 
+         results['pv_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_pvs, scen_titles)
+         results['pv_scatter_pvs_base'] = plot_scatter_comparison(cluster_pvs.compare(pvs), 'PV Calib. - PVs (Base)')
+         # results['pv_scatter_cashflows_base'] = plot_scatter_comparison(cluster_pvs.compare(cfs), 'PV Calib. - Cashflows (Base)')
+ 
+         # --- Summary Comparison Plot Data ---
+         # Error metric: absolute relative error of the total 'PV_NetCF' column where it
+         # exists, falling back to the mean of compare_total's 'error' column otherwise.
+ 
+         error_data = {}
+ 
+         # Cashflow Calibration Errors
+         if 'PV_NetCF' in pvs.columns:
+             err_cf_cal_pv_base = cluster_cfs.compare_total(pvs).loc['PV_NetCF', 'error']
+             err_cf_cal_pv_lapse = cluster_cfs.compare_total(pvs_lapse50).loc['PV_NetCF', 'error']
+             err_cf_cal_pv_mort = cluster_cfs.compare_total(pvs_mort15).loc['PV_NetCF', 'error']
+             error_data['CF Calib. (PV NetCF)'] = [
+                 abs(err_cf_cal_pv_base), abs(err_cf_cal_pv_lapse), abs(err_cf_cal_pv_mort)
+             ]
+         else:  # Fallback if PV_NetCF is not present
+             error_data['CF Calib. (PV NetCF)'] = [
+                 abs(cluster_cfs.compare_total(pvs)['error'].mean()),
+                 abs(cluster_cfs.compare_total(pvs_lapse50)['error'].mean()),
+                 abs(cluster_cfs.compare_total(pvs_mort15)['error'].mean())
+             ]
+ 
+         # Policy Attribute Calibration Errors
+         if not loc_vars_attrs.empty and 'PV_NetCF' in pvs.columns:
+             err_attr_cal_pv_base = cluster_attrs.compare_total(pvs).loc['PV_NetCF', 'error']
+             err_attr_cal_pv_lapse = cluster_attrs.compare_total(pvs_lapse50).loc['PV_NetCF', 'error']
+             err_attr_cal_pv_mort = cluster_attrs.compare_total(pvs_mort15).loc['PV_NetCF', 'error']
+             error_data['Attr Calib. (PV NetCF)'] = [
+                 abs(err_attr_cal_pv_base), abs(err_attr_cal_pv_lapse), abs(err_attr_cal_pv_mort)
+             ]
+         else:
+             error_data['Attr Calib. (PV NetCF)'] = [np.nan, np.nan, np.nan]  # Placeholder if skipped
+ 
+         # Present Value Calibration Errors
+         if 'PV_NetCF' in pvs.columns:
+             err_pv_cal_pv_base = cluster_pvs.compare_total(pvs).loc['PV_NetCF', 'error']
+             err_pv_cal_pv_lapse = cluster_pvs.compare_total(pvs_lapse50).loc['PV_NetCF', 'error']
+             err_pv_cal_pv_mort = cluster_pvs.compare_total(pvs_mort15).loc['PV_NetCF', 'error']
+             error_data['PV Calib. (PV NetCF)'] = [
+                 abs(err_pv_cal_pv_base), abs(err_pv_cal_pv_lapse), abs(err_pv_cal_pv_mort)
+             ]
+         else:
+             error_data['PV Calib. (PV NetCF)'] = [
+                 abs(cluster_pvs.compare_total(pvs)['error'].mean()),
+                 abs(cluster_pvs.compare_total(pvs_lapse50)['error'].mean()),
+                 abs(cluster_pvs.compare_total(pvs_mort15)['error'].mean())
+             ]
+ 
+         # Create Summary Plot
+         summary_df = pd.DataFrame(error_data, index=['Base', 'Lapse+50%', 'Mort+15%'])
+ 
+         fig_summary, ax_summary = plt.subplots(figsize=(10, 6))
+         summary_df.plot(kind='bar', ax=ax_summary, grid=True)
+         ax_summary.set_ylabel('Mean Absolute Error (of PV_NetCF)')
+         ax_summary.set_title('Calibration Method Comparison - Error in Total PV Net Cashflow')
+         ax_summary.tick_params(axis='x', rotation=0)
+         plt.tight_layout()
+ 
+         buf_summary = io.BytesIO()
+         plt.savefig(buf_summary, format='png', dpi=100)
+         buf_summary.seek(0)
+         results['summary_plot'] = Image.open(buf_summary)
+         plt.close(fig_summary)
+ 
+         return results
+ 
+     except FileNotFoundError as e:
+         gr.Error(f"File not found: {e.filename}. Please ensure example files are in '{EXAMPLE_DATA_DIR}' or all files are uploaded.")
+         return {"error": f"File not found: {e.filename}"}
+     except KeyError as e:
+         gr.Error(f"A required column is missing from one of the excel files: {e}. Please check data format.")
+         return {"error": f"Missing column: {e}"}
+     except Exception as e:
+         gr.Error(f"Error processing files: {str(e)}")
+         return {"error": f"Error processing files: {str(e)}"}
386
  def create_interface():
387
+     with gr.Blocks(title="Cluster Model Points Analysis") as demo: # Removed theme
388
+         gr.Markdown("""
389
+         # Cluster Model Points Analysis
390
+        
391
+         This application applies cluster analysis to model point selection for insurance portfolios.
392
+         Upload your Excel files or use the example data to analyze cashflows, policy attributes, and present values using different calibration methods.
393
+        
394
+         **Required Files (Excel .xlsx):**
395
+         - Cashflows - Base Scenario
396
+         - Cashflows - Lapse Stress (+50%)
397
+         - Cashflows - Mortality Stress (+15%)
398
+         - Policy Data (including 'age_at_entry', 'policy_term', 'sum_assured', 'duration_mth')
399
+         - Present Values - Base Scenario
400
+         - Present Values - Lapse Stress
401
+         - Present Values - Mortality Stress
402
+         """)
403
+        
404
+         with gr.Row():
405
+             with gr.Column(scale=1):
406
+                 gr.Markdown("### Upload Files or Load Examples")
407
+                
408
+                 load_example_btn = gr.Button("Load Example Data")
409
+                
410
+                 with gr.Row():
411
+                     cashflow_base_input = gr.File(label="Cashflows - Base", file_types=[".xlsx"])
412
+                     cashflow_lapse_input = gr.File(label="Cashflows - Lapse Stress", file_types=[".xlsx"])
413
+                     cashflow_mort_input = gr.File(label="Cashflows - Mortality Stress", file_types=[".xlsx"])
414
+                 with gr.Row():
415
+                     policy_data_input = gr.File(label="Policy Data", file_types=[".xlsx"])
416
+                     pv_base_input = gr.File(label="Present Values - Base", file_types=[".xlsx"])
417
+                     pv_lapse_input = gr.File(label="Present Values - Lapse Stress", file_types=[".xlsx"])
418
+                 with gr.Row():
419
+                     pv_mort_input = gr.File(label="Present Values - Mortality Stress", file_types=[".xlsx"])
420
+                
421
+                 analyze_btn = gr.Button("Analyze Dataset", variant="primary", size="lg")
422
+        
423
+         with gr.Tabs():
+             with gr.TabItem("Summary"):
+                 summary_plot_output = gr.Image(label="Calibration Methods Comparison (Error in Total PV Net Cashflow)")
+
+             with gr.TabItem("Cashflow Calibration"):
+                 gr.Markdown("### Results: Using Annual Cashflows as Calibration Variables")
+                 with gr.Row():
+                     cf_total_base_table_out = gr.Dataframe(label="Overall Comparison - Base Scenario (Cashflows)")
+                     cf_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes")
+                 cf_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate) Across Scenarios")
+                 cf_scatter_cashflows_base_out = gr.Image(label="Scatter Plot - Per-Cluster Cashflows (Base Scenario)")
+                 with gr.Accordion("Present Value Comparisons (Total)", open=False):
+                     with gr.Row():
+                         cf_pv_total_base_out = gr.Dataframe(label="PVs - Base Total")
+                         cf_pv_total_lapse_out = gr.Dataframe(label="PVs - Lapse Stress Total")
+                         cf_pv_total_mort_out = gr.Dataframe(label="PVs - Mortality Stress Total")
+
+             with gr.TabItem("Policy Attribute Calibration"):
+                 gr.Markdown("### Results: Using Policy Attributes as Calibration Variables")
+                 with gr.Row():
+                     attr_total_cf_base_out = gr.Dataframe(label="Overall Comparison - Base Scenario (Cashflows)")
+                     attr_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes")
+                 attr_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate) Across Scenarios")
+                 attr_scatter_cashflows_base_out = gr.Image(label="Scatter Plot - Per-Cluster Cashflows (Base Scenario)")
+                 with gr.Accordion("Present Value Comparisons (Total)", open=False):
+                     attr_total_pv_base_out = gr.Dataframe(label="PVs - Base Scenario Total")
+
+             with gr.TabItem("Present Value Calibration"):
+                 gr.Markdown("### Results: Using Present Values (Base Scenario) as Calibration Variables")
+                 with gr.Row():
+                     pv_total_cf_base_out = gr.Dataframe(label="Overall Comparison - Base Scenario (Cashflows)")
+                     pv_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes")
+                 pv_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate) Across Scenarios")
+                 pv_scatter_pvs_base_out = gr.Image(label="Scatter Plot - Per-Cluster Present Values (Base Scenario)")
+                 with gr.Accordion("Present Value Comparisons (Total)", open=False):
+                     with gr.Row():
+                         pv_total_pv_base_out = gr.Dataframe(label="PVs - Base Total")
+                         pv_total_pv_lapse_out = gr.Dataframe(label="PVs - Lapse Stress Total")
+                         pv_total_pv_mort_out = gr.Dataframe(label="PVs - Mortality Stress Total")
+
+         # --- Helper function to prepare outputs ---
+         def get_all_output_components():
+             return [
+                 summary_plot_output,
+                 # Cashflow Calib Outputs
+                 cf_total_base_table_out, cf_policy_attrs_total_out,
+                 cf_cashflow_plot_out, cf_scatter_cashflows_base_out,
+                 cf_pv_total_base_out, cf_pv_total_lapse_out, cf_pv_total_mort_out,
+                 # Attribute Calib Outputs
+                 attr_total_cf_base_out, attr_policy_attrs_total_out,
+                 attr_cashflow_plot_out, attr_scatter_cashflows_base_out, attr_total_pv_base_out,
+                 # PV Calib Outputs
+                 pv_total_cf_base_out, pv_policy_attrs_total_out,
+                 pv_cashflow_plot_out, pv_scatter_pvs_base_out,
+                 pv_total_pv_base_out, pv_total_pv_lapse_out, pv_total_pv_mort_out
+             ]
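+         # NOTE: this list fixes the positional output order; the list returned by
+         # handle_analysis below must follow exactly the same order, because Gradio
+         # maps handler return values to output components by position.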
+
+         # --- Action for Analyze Button ---
+         def handle_analysis(f1, f2, f3, f4, f5, f6, f7):
+             # Ensure all files are provided (either by upload or example load)
+             files = [f1, f2, f3, f4, f5, f6, f7]
+             # Gradio file objects expose the temp path via their .name attribute;
+             # values set by the example loader are already plain string paths.
+             file_paths = []
+             for i, f_obj in enumerate(files):
+                 if f_obj is None:
+                     raise gr.Error(f"Missing file input for argument {i+1}. Please upload all files or load examples.")
+
+                 # f_obj is a Gradio file object (from direct upload)
+                 if hasattr(f_obj, 'name') and isinstance(f_obj.name, str):
+                     file_paths.append(f_obj.name)
+                 # f_obj is already a string path (from example load)
+                 elif isinstance(f_obj, str):
+                     file_paths.append(f_obj)
+                 else:
+                     raise gr.Error(f"Invalid file input for argument {i+1}. Type: {type(f_obj)}")
+
+             results = process_files(*file_paths)
+
+             # Defensive check: process_files raises gr.Error on failure, but an
+             # "error" key in the results would also abort the update here.
+             if "error" in results:
+                 return [None] * len(get_all_output_components())
+
+             # Must mirror the component order in get_all_output_components().
+             return [
+                 results.get('summary_plot'),
+                 # CF Calib
+                 results.get('cf_total_base_table'), results.get('cf_policy_attrs_total'),
+                 results.get('cf_cashflow_plot'), results.get('cf_scatter_cashflows_base'),
+                 results.get('cf_pv_total_base'), results.get('cf_pv_total_lapse'), results.get('cf_pv_total_mort'),
+                 # Attr Calib
+                 results.get('attr_total_cf_base'), results.get('attr_policy_attrs_total'),
+                 results.get('attr_cashflow_plot'), results.get('attr_scatter_cashflows_base'), results.get('attr_total_pv_base'),
+                 # PV Calib
+                 results.get('pv_total_cf_base'), results.get('pv_policy_attrs_total'),
+                 results.get('pv_cashflow_plot'), results.get('pv_scatter_pvs_base'),
+                 results.get('pv_total_pv_base'), results.get('pv_total_pv_lapse'), results.get('pv_total_pv_mort')
+             ]
+
+         analyze_btn.click(
+             handle_analysis,
+             inputs=[cashflow_base_input, cashflow_lapse_input, cashflow_mort_input,
+                     policy_data_input, pv_base_input, pv_lapse_input, pv_mort_input],
+             outputs=get_all_output_components()
+         )
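+         # The inputs list above is passed to handle_analysis positionally (f1..f7),
+         # and outputs reuses the shared component list, keeping the wiring and the
+         # handler's return order in sync.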
+
+         # --- Action for Load Example Data Button ---
+         def load_example_files():
+             # Verify that every example file exists before handing its path to the UI
+             missing_files = [fp for fp in EXAMPLE_FILES.values() if not os.path.exists(fp)]
+             if missing_files:
+                 raise gr.Error(f"Missing example data files in '{EXAMPLE_DATA_DIR}': {', '.join(missing_files)}. Please ensure they exist.")
+
+             gr.Info("Example data paths loaded. Click 'Analyze Dataset'.")
+             # Returning plain path strings populates the gr.File components, so the
+             # same analyze flow works for uploads and examples alike.
+             return [
+                 EXAMPLE_FILES["cashflow_base"], EXAMPLE_FILES["cashflow_lapse"], EXAMPLE_FILES["cashflow_mort"],
+                 EXAMPLE_FILES["policy_data"], EXAMPLE_FILES["pv_base"], EXAMPLE_FILES["pv_lapse"],
+                 EXAMPLE_FILES["pv_mort"]
+             ]
+
+         load_example_btn.click(
+             load_example_files,
+             inputs=[],
+             outputs=[cashflow_base_input, cashflow_lapse_input, cashflow_mort_input,
+                      policy_data_input, pv_base_input, pv_lapse_input, pv_mort_input]
+         )
+
+     return demo


  if __name__ == "__main__":
+     # Create the eg_data directory if it doesn't exist; the user is expected to
+     # populate it with the example Excel files.
+     if not os.path.exists(EXAMPLE_DATA_DIR):
+         os.makedirs(EXAMPLE_DATA_DIR)
+         print(f"Created directory '{EXAMPLE_DATA_DIR}'. Please place the example Excel files there.")
+         # Writing empty placeholder files (e.g. open(path, "w") followed by
+         # f.write("")) would not help: they are not valid workbooks, so
+         # pd.read_excel would fail on them. It is better to instruct the user
+         # to add the actual files.
+         print(f"Expected files in '{EXAMPLE_DATA_DIR}': {list(EXAMPLE_FILES.values())}")
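+         # If minimally valid workbooks are wanted purely for smoke-testing the UI,
+         # something like the sketch below could be uncommented. It assumes openpyxl
+         # is installed; the single 'policy_id' column is a hypothetical stand-in,
+         # and real files must follow the expected schema (e.g. the policy data
+         # needs 'age_at_entry', 'policy_term', 'sum_assured', 'duration_mth').
+         # for fp in EXAMPLE_FILES.values():
+         #     pd.DataFrame({"policy_id": [1]}).to_excel(fp, index=False)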
+
+
+     demo_app = create_interface()
+     demo_app.launch()