alidenewade committed on
Commit 4e456a7 · verified · 1 Parent(s): e82ad24

Update app.py

Files changed (1):
  1. app.py +574 -542

app.py CHANGED
@@ -2,7 +2,7 @@ import gradio as gr
  import numpy as np
  import pandas as pd
  from sklearn.cluster import KMeans
- from sklearn.metrics import pairwise_distances_argmin_min, r2_score
  import matplotlib.pyplot as plt
  import matplotlib.cm
  import io
@@ -12,560 +12,592 @@ from PIL import Image
  # Define the paths for example data
  EXAMPLE_DATA_DIR = "eg_data"
  EXAMPLE_FILES = {
-     "cashflow_base": os.path.join(EXAMPLE_DATA_DIR, "cashflows_seriatim_10K.xlsx"),
-     "cashflow_lapse": os.path.join(EXAMPLE_DATA_DIR, "cashflows_seriatim_10K_lapse50.xlsx"),
-     "cashflow_mort": os.path.join(EXAMPLE_DATA_DIR, "cashflows_seriatim_10K_mort15.xlsx"),
-     "policy_data": os.path.join(EXAMPLE_DATA_DIR, "model_point_table.xlsx"), # Assuming this is the correct path/name for the example
-     "pv_base": os.path.join(EXAMPLE_DATA_DIR, "pv_seriatim_10K.xlsx"),
-     "pv_lapse": os.path.join(EXAMPLE_DATA_DIR, "pv_seriatim_10K_lapse50.xlsx"),
-     "pv_mort": os.path.join(EXAMPLE_DATA_DIR, "pv_seriatim_10K_mort15.xlsx"),
  }

  class Clusters:
-     def __init__(self, loc_vars):
-         self.kmeans = kmeans = KMeans(n_clusters=1000, random_state=0, n_init=10).fit(np.ascontiguousarray(loc_vars))
-         closest, _ = pairwise_distances_argmin_min(kmeans.cluster_centers_, np.ascontiguousarray(loc_vars))
-
-         rep_ids = pd.Series(data=(closest+1))  # 0-based to 1-based indexes
-         rep_ids.name = 'policy_id'
-         rep_ids.index.name = 'cluster_id'
-         self.rep_ids = rep_ids
-
-         self.policy_count = self.agg_by_cluster(pd.DataFrame({'policy_count': [1] * len(loc_vars)}))['policy_count']
-
-     def agg_by_cluster(self, df, agg=None):
-         """Aggregate columns by cluster"""
-         temp = df.copy()
-         temp['cluster_id'] = self.kmeans.labels_
-         temp = temp.set_index('cluster_id')
-         agg = {c: (agg[c] if agg and c in agg else 'sum') for c in temp.columns} if agg else "sum"
-         return temp.groupby(temp.index).agg(agg)
-
-     def extract_reps(self, df):
-         """Extract the rows of representative policies"""
-         temp = pd.merge(self.rep_ids, df.reset_index(), how='left', on='policy_id')
-         temp.index.name = 'cluster_id'
-         return temp.drop('policy_id', axis=1)
-
-     def extract_and_scale_reps(self, df, agg=None):
-         """Extract and scale the rows of representative policies"""
-         if agg:
-             cols = df.columns
-             mult = pd.DataFrame({c: (self.policy_count if (c not in agg or agg[c] == 'sum') else 1) for c in cols})
-             # Ensure mult has same index as extract_reps(df) for proper alignment
-             extracted_df = self.extract_reps(df)
-             mult.index = extracted_df.index
-             return extracted_df.mul(mult)
-         else:
-             return self.extract_reps(df).mul(self.policy_count, axis=0)
-
-     def compare(self, df, agg=None):
-         """Returns a multi-indexed Dataframe comparing actual and estimate"""
-         source = self.agg_by_cluster(df, agg)
-         target = self.extract_and_scale_reps(df, agg)
-         return pd.DataFrame({'actual': source.stack(), 'estimate': target.stack()})
-
-     def compare_total(self, df, agg=None):
-         """Aggregate df by columns"""
-         if agg:
-             # cols = df.columns # Not used
-             op = {c: (agg[c] if c in agg else 'sum') for c in df.columns}
-             actual = df.agg(op)
-
-             # For estimate, ensure aggregation ops are correctly applied *after* scaling
-             scaled_reps = self.extract_and_scale_reps(df, agg=op) # Pass op to ensure correct scaling for mean
-
-             # Corrected aggregation for estimate when 'mean' is involved
-             estimate_agg_ops = {}
-             for col_name, agg_type in op.items():
-                 if agg_type == 'mean':
-                     # Weighted average for mean columns
-                     estimate_agg_ops[col_name] = lambda s, c=col_name: (s * self.policy_count.reindex(s.index)).sum() / self.policy_count.reindex(s.index).sum() if c in self.policy_count.name else s.mean()
-                 else: # 'sum'
-                     estimate_agg_ops[col_name] = 'sum'
-
-             # Need to handle the case where extract_and_scale_reps already applied scaling for sum
-             # The logic in extract_and_scale_reps is:
-             # mult = pd.DataFrame({c: (self.policy_count if (c not in agg or agg[c] == 'sum') else 1) for c in cols})
-             # This means 'mean' columns are NOT multiplied by policy_count initially.
-
-             # Let's re-think the estimate aggregation for 'mean'
-             estimate_scaled = self.extract_and_scale_reps(df, agg=op) # agg=op is important here
-
-             final_estimate_ops = {}
-             for col, method in op.items():
-                 if method == 'mean':
-                     # For mean, we need the sum of (value * policy_count) / sum(policy_count)
-                     # extract_and_scale_reps with agg=op should have scaled sum-columns by policy_count
-                     # and mean-columns by 1. So, for mean columns in estimate_scaled, we need to multiply by policy_count,
-                     # sum them up, and divide by total policy_count.
-                     # However, the current extract_and_scale_reps scales 'mean' columns by 1.
-                     # So we need to take the mean of these scaled (by 1) values, but it should be a weighted mean.
-
-                     # Let's try to be more direct:
-                     # Get the representative policies (unscaled for mean columns)
-                     reps_unscaled_for_mean = self.extract_reps(df)
-                     estimate_values = {}
-                     for c in df.columns:
-                         if op[c] == 'sum':
-                            estimate_values[c] = reps_unscaled_for_mean[c].mul(self.policy_count, axis=0).sum()
-                         elif op[c] == 'mean':
-                            weighted_sum = (reps_unscaled_for_mean[c] * self.policy_count).sum()
-                            total_weight = self.policy_count.sum()
-                            estimate_values[c] = weighted_sum / total_weight if total_weight else 0
-                     estimate = pd.Series(estimate_values)
-
-                 else: # original 'sum' logic for all columns
-                     final_estimate_ops[col] = 'sum' # All columns in estimate_scaled are ready to be summed up
-                     estimate = estimate_scaled.agg(final_estimate_ops)
-
-         else: # Original logic if no agg is specified (all sum)
-             actual = df.sum()
-             estimate = self.extract_and_scale_reps(df).sum()
-
-         return pd.DataFrame({'actual': actual, 'estimate': estimate, 'error': estimate / actual - 1})
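
The tangle of comments in the removed `compare_total` is reasoning toward one computation: a policy-count-weighted average for 'mean' columns. A minimal standalone sketch of that estimate, using hypothetical toy values rather than the app's data:

```python
import pandas as pd

# Hypothetical representative values and cluster sizes (not from the app's data).
rep_age = pd.Series([30.0, 45.0, 60.0])   # one representative age per cluster
policy_count = pd.Series([100, 50, 10])   # policies per cluster

# Weighted mean: sum(rep_value * cluster_size) / sum(cluster_size)
estimate = (rep_age * policy_count).sum() / policy_count.sum()
print(estimate)  # (3000 + 2250 + 600) / 160 = 36.5625
```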

  def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
-     """Create cashflow comparison plots"""
-     if not cfs_list or not cluster_obj or not titles:
-         return None # Or a placeholder image
-     num_plots = len(cfs_list)
-     if num_plots == 0:
-         return None
-
-     # Determine subplot layout (e.g., 2x2 or adapt)
-     cols = 2
-     rows = (num_plots + cols - 1) // cols
-
-     fig, axes = plt.subplots(rows, cols, figsize=(15, 5 * rows), squeeze=False) # Ensure axes is always 2D
-     axes = axes.flatten()
-
-     for i, (df, title) in enumerate(zip(cfs_list, titles)):
-         if i < len(axes):
-             comparison = cluster_obj.compare_total(df)
-             comparison[['actual', 'estimate']].plot(ax=axes[i], grid=True, title=title)
-             axes[i].set_xlabel('Time') # Assuming x-axis is time for cashflows
-             axes[i].set_ylabel('Value')
-
-     # Hide any unused subplots
-     for j in range(i + 1, len(axes)):
-         fig.delaxes(axes[j])
-
-     plt.tight_layout()
-     buf = io.BytesIO()
-     plt.savefig(buf, format='png', dpi=100) # Lowered DPI slightly for potentially faster rendering
-     buf.seek(0)
-     img = Image.open(buf)
-     plt.close(fig) # Ensure figure is closed
-     return img
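
Every plotting function in this file renders via the same pattern: save the figure to an in-memory PNG buffer, reopen it as a PIL image, and close the figure. A sketch of a reusable helper for that pattern, assuming the same matplotlib/PIL stack (the name `fig_to_pil` is illustrative, not part of the app):

```python
import io
import matplotlib.pyplot as plt
from PIL import Image

def fig_to_pil(fig, dpi=100):
    """Render a matplotlib figure to a PIL image and release the figure."""
    buf = io.BytesIO()
    fig.savefig(buf, format='png', dpi=dpi)
    buf.seek(0)
    img = Image.open(buf)
    plt.close(fig)  # avoid leaking figures across Gradio requests
    return img
```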

  def plot_scatter_comparison(df_compare_output, title):
-     """Create scatter plot comparison from compare() output"""
-     if df_compare_output is None or df_compare_output.empty:
-         # Create a blank plot with a message
-         fig, ax = plt.subplots(figsize=(12, 8))
-         ax.text(0.5, 0.5, "No data to display", ha='center', va='center', fontsize=15)
-         ax.set_title(title)
-         buf = io.BytesIO()
-         plt.savefig(buf, format='png', dpi=100)
-         buf.seek(0)
-         img = Image.open(buf)
-         plt.close(fig)
-         return img
-
-     fig, ax = plt.subplots(figsize=(12, 8)) # Use a single Axes object
-
-     if not isinstance(df_compare_output.index, pd.MultiIndex) or df_compare_output.index.nlevels < 2:
-          gr.Warning("Scatter plot data is not in the expected multi-index format. Plotting raw actual vs estimate.")
-          ax.scatter(df_compare_output['actual'], df_compare_output['estimate'], s=9, alpha=0.6)
-     else:
-         unique_levels = df_compare_output.index.get_level_values(1).unique()
-         colors = matplotlib.cm.rainbow(np.linspace(0, 1, len(unique_levels)))
-
-         for item_level, color_val in zip(unique_levels, colors):
-             subset = df_compare_output.xs(item_level, level=1)
-             ax.scatter(subset['actual'], subset['estimate'], color=color_val, s=9, alpha=0.6, label=item_level)
-         if len(unique_levels) > 1 and len(unique_levels) <= 10: # Add legend if not too many items
-             ax.legend(title=df_compare_output.index.names[1])
-
-     ax.set_xlabel('Actual')
-     ax.set_ylabel('Estimate')
-     ax.set_title(title)
-     ax.grid(True)
-
-     # Draw identity line
-     lims = [
-         np.min([ax.get_xlim(), ax.get_ylim()]),
-         np.max([ax.get_xlim(), ax.get_ylim()]),
-     ]
-     if lims[0] != lims[1]: # Avoid issues if all data is zero or a single point
-       ax.plot(lims, lims, 'r-', linewidth=0.5)
-       ax.set_xlim(lims)
-       ax.set_ylim(lims)
-
-     buf = io.BytesIO()
-     plt.savefig(buf, format='png', dpi=100)
-     buf.seek(0)
-     img = Image.open(buf)
-     plt.close(fig)
-     return img

  def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
-                   policy_data_path, pv_base_path, pv_lapse_path, pv_mort_path):
-     """Main processing function - now accepts file paths"""
-     try:
-         # Read uploaded files using paths
-         cfs = pd.read_excel(cashflow_base_path, index_col=0)
-         cfs_lapse50 = pd.read_excel(cashflow_lapse_path, index_col=0)
-         cfs_mort15 = pd.read_excel(cashflow_mort_path, index_col=0)
-
-         pol_data_full = pd.read_excel(policy_data_path, index_col=0)
-         # Ensure the correct columns are selected for pol_data
-         required_cols = ['age_at_entry', 'policy_term', 'sum_assured', 'duration_mth']
-         if all(col in pol_data_full.columns for col in required_cols):
-             pol_data = pol_data_full[required_cols]
-         else:
-             # Fallback or error if columns are missing. For now, try to use as is or a subset.
-             gr.Warning(f"Policy data might be missing required columns. Found: {pol_data_full.columns.tolist()}")
-             pol_data = pol_data_full
-
-         pvs = pd.read_excel(pv_base_path, index_col=0)
-         pvs_lapse50 = pd.read_excel(pv_lapse_path, index_col=0)
-         pvs_mort15 = pd.read_excel(pv_mort_path, index_col=0)
-
-         cfs_list = [cfs, cfs_lapse50, cfs_mort15]
-         # pvs_list = [pvs, pvs_lapse50, pvs_mort15] # Not directly used for plotting in this structure
-         scen_titles = ['Base', 'Lapse+50%', 'Mort+15%']
-
-         results = {}
-
-         mean_attrs = {'age_at_entry':'mean', 'policy_term':'mean', 'duration_mth':'mean', 'sum_assured': 'sum'} # sum_assured is usually summed
-
-         # --- 1. Cashflow Calibration ---
-         cluster_cfs = Clusters(cfs)
-
-         results['cf_total_base_table'] = cluster_cfs.compare_total(cfs)
-         # results['cf_total_lapse_table'] = cluster_cfs.compare_total(cfs_lapse50) # For full detail if needed
-         # results['cf_total_mort_table'] = cluster_cfs.compare_total(cfs_mort15)
-
-         results['cf_policy_attrs_total'] = cluster_cfs.compare_total(pol_data, agg=mean_attrs)
-
-         results['cf_pv_total_base'] = cluster_cfs.compare_total(pvs)
-         results['cf_pv_total_lapse'] = cluster_cfs.compare_total(pvs_lapse50)
-         results['cf_pv_total_mort'] = cluster_cfs.compare_total(pvs_mort15)
-
-         results['cf_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_cfs, scen_titles)
-         results['cf_scatter_cashflows_base'] = plot_scatter_comparison(cluster_cfs.compare(cfs), 'Cashflow Calib. - Cashflows (Base)')
-         # results['cf_scatter_policy_attrs'] = plot_scatter_comparison(cluster_cfs.compare(pol_data, agg=mean_attrs), 'Cashflow Calib. - Policy Attributes')
-         # results['cf_scatter_pvs_base'] = plot_scatter_comparison(cluster_cfs.compare(pvs), 'Cashflow Calib. - PVs (Base)')
-
-         # --- 2. Policy Attribute Calibration ---
-         # Standardize policy attributes
-         if not pol_data.empty and (pol_data.max() - pol_data.min()).all() != 0 : # Avoid division by zero if a column is constant
-              loc_vars_attrs = (pol_data - pol_data.min()) / (pol_data.max() - pol_data.min())
-         else:
-             gr.Warning("Policy data for attribute calibration is empty or has no variance. Skipping attribute calibration plots.")
-             loc_vars_attrs = pol_data # or handle as an error/skip
-
-         if not loc_vars_attrs.empty:
-             cluster_attrs = Clusters(loc_vars_attrs)
-             results['attr_total_cf_base'] = cluster_attrs.compare_total(cfs)
-             results['attr_policy_attrs_total'] = cluster_attrs.compare_total(pol_data, agg=mean_attrs)
-             results['attr_total_pv_base'] = cluster_attrs.compare_total(pvs)
-             results['attr_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_attrs, scen_titles)
-             results['attr_scatter_cashflows_base'] = plot_scatter_comparison(cluster_attrs.compare(cfs), 'Policy Attr. Calib. - Cashflows (Base)')
-             # results['attr_scatter_policy_attrs'] = plot_scatter_comparison(cluster_attrs.compare(pol_data, agg=mean_attrs), 'Policy Attr. Calib. - Policy Attributes')
-
-         else: # Fill with None if skipped
-             results['attr_total_cf_base'] = pd.DataFrame()
-             results['attr_policy_attrs_total'] = pd.DataFrame()
-             results['attr_total_pv_base'] = pd.DataFrame()
-             results['attr_cashflow_plot'] = None
-             results['attr_scatter_cashflows_base'] = None
-
-         # --- 3. Present Value Calibration ---
-         cluster_pvs = Clusters(pvs)
-
-         results['pv_total_cf_base'] = cluster_pvs.compare_total(cfs)
-         results['pv_policy_attrs_total'] = cluster_pvs.compare_total(pol_data, agg=mean_attrs)
-
-         results['pv_total_pv_base'] = cluster_pvs.compare_total(pvs)
-         results['pv_total_pv_lapse'] = cluster_pvs.compare_total(pvs_lapse50)
-         results['pv_total_pv_mort'] = cluster_pvs.compare_total(pvs_mort15)
-
-         results['pv_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_pvs, scen_titles)
-         results['pv_scatter_pvs_base'] = plot_scatter_comparison(cluster_pvs.compare(pvs), 'PV Calib. - PVs (Base)')
-         # results['pv_scatter_cashflows_base'] = plot_scatter_comparison(cluster_pvs.compare(cfs), 'PV Calib. - Cashflows (Base)')
-
-         # --- Summary Comparison Plot Data ---
-         # Error metric: Mean Absolute Percentage Error for the 'TOTAL' net present value of cashflows (usually the 'PV_NetCF' column)
-         # Or sum of absolute errors if percentage is problematic (e.g. actual is zero)
-         # For simplicity, using mean of the 'error' column from compare_total for key metrics
-
-         error_data = {}
-
-         # Cashflow Calibration Errors
-         if 'PV_NetCF' in pvs.columns:
-             err_cf_cal_pv_base = cluster_cfs.compare_total(pvs).loc['PV_NetCF', 'error']
-             err_cf_cal_pv_lapse = cluster_cfs.compare_total(pvs_lapse50).loc['PV_NetCF', 'error']
-             err_cf_cal_pv_mort = cluster_cfs.compare_total(pvs_mort15).loc['PV_NetCF', 'error']
-             error_data['CF Calib. (PV NetCF)'] = [
-                 abs(err_cf_cal_pv_base), abs(err_cf_cal_pv_lapse), abs(err_cf_cal_pv_mort)
-             ]
-         else: # Fallback if PV_NetCF is not present
-             error_data['CF Calib. (PV NetCF)'] = [
-                 abs(cluster_cfs.compare_total(pvs)['error'].mean()),
-                 abs(cluster_cfs.compare_total(pvs_lapse50)['error'].mean()),
-                 abs(cluster_cfs.compare_total(pvs_mort15)['error'].mean())
-             ]
-
-         # Policy Attribute Calibration Errors
-         if not loc_vars_attrs.empty and 'PV_NetCF' in pvs.columns:
-             err_attr_cal_pv_base = cluster_attrs.compare_total(pvs).loc['PV_NetCF', 'error']
-             err_attr_cal_pv_lapse = cluster_attrs.compare_total(pvs_lapse50).loc['PV_NetCF', 'error']
-             err_attr_cal_pv_mort = cluster_attrs.compare_total(pvs_mort15).loc['PV_NetCF', 'error']
-             error_data['Attr Calib. (PV NetCF)'] = [
-                 abs(err_attr_cal_pv_base), abs(err_attr_cal_pv_lapse), abs(err_attr_cal_pv_mort)
-             ]
-         else:
-              error_data['Attr Calib. (PV NetCF)'] = [np.nan, np.nan, np.nan] # Placeholder if skipped
-
-         # Present Value Calibration Errors
-         if 'PV_NetCF' in pvs.columns:
-             err_pv_cal_pv_base = cluster_pvs.compare_total(pvs).loc['PV_NetCF', 'error']
-             err_pv_cal_pv_lapse = cluster_pvs.compare_total(pvs_lapse50).loc['PV_NetCF', 'error']
-             err_pv_cal_pv_mort = cluster_pvs.compare_total(pvs_mort15).loc['PV_NetCF', 'error']
-             error_data['PV Calib. (PV NetCF)'] = [
-                 abs(err_pv_cal_pv_base), abs(err_pv_cal_pv_lapse), abs(err_pv_cal_pv_mort)
-             ]
-         else:
-             error_data['PV Calib. (PV NetCF)'] = [
-                 abs(cluster_pvs.compare_total(pvs)['error'].mean()),
-                 abs(cluster_pvs.compare_total(pvs_lapse50)['error'].mean()),
-                 abs(cluster_pvs.compare_total(pvs_mort15)['error'].mean())
-             ]
-
-         # Create Summary Plot
-         summary_df = pd.DataFrame(error_data, index=['Base', 'Lapse+50%', 'Mort+15%'])
-
-         fig_summary, ax_summary = plt.subplots(figsize=(10, 6))
-         summary_df.plot(kind='bar', ax=ax_summary, grid=True)
-         ax_summary.set_ylabel('Mean Absolute Error (of PV_NetCF)')
-         ax_summary.set_title('Calibration Method Comparison - Error in Total PV Net Cashflow')
-         ax_summary.tick_params(axis='x', rotation=0)
-         plt.tight_layout()
-
-         buf_summary = io.BytesIO()
-         plt.savefig(buf_summary, format='png', dpi=100)
-         buf_summary.seek(0)
-         results['summary_plot'] = Image.open(buf_summary)
-         plt.close(fig_summary)
-
-         return results
-
-     except FileNotFoundError as e:
-         gr.Error(f"File not found: {e.filename}. Please ensure example files are in '{EXAMPLE_DATA_DIR}' or all files are uploaded.")
-         return {"error": f"File not found: {e.filename}"}
-     except KeyError as e:
-         gr.Error(f"A required column is missing from one of the excel files: {e}. Please check data format.")
-         return {"error": f"Missing column: {e}"}
-     except Exception as e:
-         gr.Error(f"Error processing files: {str(e)}")
-         return {"error": f"Error processing files: {str(e)}"}


  def create_interface():
-     with gr.Blocks(title="Cluster Model Points Analysis") as demo: # Removed theme
-         gr.Markdown("""
-         # Cluster Model Points Analysis
-
-         This application applies cluster analysis to model point selection for insurance portfolios.
-         Upload your Excel files or use the example data to analyze cashflows, policy attributes, and present values using different calibration methods.
-
-         **Required Files (Excel .xlsx):**
-         - Cashflows - Base Scenario
-         - Cashflows - Lapse Stress (+50%)
-         - Cashflows - Mortality Stress (+15%)
-         - Policy Data (including 'age_at_entry', 'policy_term', 'sum_assured', 'duration_mth')
-         - Present Values - Base Scenario
-         - Present Values - Lapse Stress
-         - Present Values - Mortality Stress
-         """)
-
-         with gr.Row():
-             with gr.Column(scale=1):
-                 gr.Markdown("### Upload Files or Load Examples")
-
-                 load_example_btn = gr.Button("Load Example Data")
-
-                 with gr.Row():
-                     cashflow_base_input = gr.File(label="Cashflows - Base", file_types=[".xlsx"])
-                     cashflow_lapse_input = gr.File(label="Cashflows - Lapse Stress", file_types=[".xlsx"])
-                     cashflow_mort_input = gr.File(label="Cashflows - Mortality Stress", file_types=[".xlsx"])
-                 with gr.Row():
-                     policy_data_input = gr.File(label="Policy Data", file_types=[".xlsx"])
-                     pv_base_input = gr.File(label="Present Values - Base", file_types=[".xlsx"])
-                     pv_lapse_input = gr.File(label="Present Values - Lapse Stress", file_types=[".xlsx"])
-                 with gr.Row():
-                     pv_mort_input = gr.File(label="Present Values - Mortality Stress", file_types=[".xlsx"])
-
-                 analyze_btn = gr.Button("Analyze Dataset", variant="primary", size="lg")
-
-         with gr.Tabs():
-             with gr.TabItem(" Summary"):
-                 summary_plot_output = gr.Image(label="Calibration Methods Comparison (Error in Total PV Net Cashflow)")
-
-             with gr.TabItem(" Cashflow Calibration"):
-                 gr.Markdown("### Results: Using Annual Cashflows as Calibration Variables")
-                 with gr.Row():
-                     cf_total_base_table_out = gr.Dataframe(label="Overall Comparison - Base Scenario (Cashflows)")
-                     cf_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes")
-                 cf_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate) Across Scenarios")
-                 cf_scatter_cashflows_base_out = gr.Image(label="Scatter Plot - Per-Cluster Cashflows (Base Scenario)")
-                 with gr.Accordion("Present Value Comparisons (Total)", open=False):
-                     with gr.Row():
-                         cf_pv_total_base_out = gr.Dataframe(label="PVs - Base Total")
-                         cf_pv_total_lapse_out = gr.Dataframe(label="PVs - Lapse Stress Total")
-                         cf_pv_total_mort_out = gr.Dataframe(label="PVs - Mortality Stress Total")
-
-             with gr.TabItem(" Policy Attribute Calibration"):
-                 gr.Markdown("### Results: Using Policy Attributes as Calibration Variables")
-                 with gr.Row():
-                     attr_total_cf_base_out = gr.Dataframe(label="Overall Comparison - Base Scenario (Cashflows)")
-                     attr_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes")
-                 attr_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate) Across Scenarios")
-                 attr_scatter_cashflows_base_out = gr.Image(label="Scatter Plot - Per-Cluster Cashflows (Base Scenario)")
-                 with gr.Accordion("Present Value Comparisons (Total)", open=False):
-                      attr_total_pv_base_out = gr.Dataframe(label="PVs - Base Scenario Total")
-
-             with gr.TabItem(" Present Value Calibration"):
-                 gr.Markdown("### Results: Using Present Values (Base Scenario) as Calibration Variables")
-                 with gr.Row():
-                     pv_total_cf_base_out = gr.Dataframe(label="Overall Comparison - Base Scenario (Cashflows)")
-                     pv_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes")
-                 pv_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate) Across Scenarios")
-                 pv_scatter_pvs_base_out = gr.Image(label="Scatter Plot - Per-Cluster Present Values (Base Scenario)")
-                 with gr.Accordion("Present Value Comparisons (Total)", open=False):
-                     with gr.Row():
-                         pv_total_pv_base_out = gr.Dataframe(label="PVs - Base Total")
-                         pv_total_pv_lapse_out = gr.Dataframe(label="PVs - Lapse Stress Total")
-                         pv_total_pv_mort_out = gr.Dataframe(label="PVs - Mortality Stress Total")
-
-         # --- Helper function to prepare outputs ---
-         def get_all_output_components():
-             return [
-                 summary_plot_output,
-                 # Cashflow Calib Outputs
-                 cf_total_base_table_out, cf_policy_attrs_total_out,
-                 cf_cashflow_plot_out, cf_scatter_cashflows_base_out,
-                 cf_pv_total_base_out, cf_pv_total_lapse_out, cf_pv_total_mort_out,
-                 # Attribute Calib Outputs
-                 attr_total_cf_base_out, attr_policy_attrs_total_out,
-                 attr_cashflow_plot_out, attr_scatter_cashflows_base_out, attr_total_pv_base_out,
-                 # PV Calib Outputs
-                 pv_total_cf_base_out, pv_policy_attrs_total_out,
-                 pv_cashflow_plot_out, pv_scatter_pvs_base_out,
-                 pv_total_pv_base_out, pv_total_pv_lapse_out, pv_total_pv_mort_out
-             ]
-
-         # --- Action for Analyze Button ---
-         def handle_analysis(f1, f2, f3, f4, f5, f6, f7):
-             # Ensure all files are provided (either by upload or example load)
-             files = [f1, f2, f3, f4, f5, f6, f7]
-             # Gradio File objects have a .name attribute for the temp path
-             # If they are already strings (from example load), they are paths
-
-             file_paths = []
-             for i, f_obj in enumerate(files):
-                 if f_obj is None:
-                     gr.Error(f"Missing file input for argument {i+1}. Please upload all files or load examples.")
-                     # Return Nones for all output components
-                     return [None] * len(get_all_output_components())
-
-                 # If f_obj is a Gradio FileData object (from direct upload)
-                 if hasattr(f_obj, 'name') and isinstance(f_obj.name, str):
-                     file_paths.append(f_obj.name)
-                 # If f_obj is already a string path (from example load)
-                 elif isinstance(f_obj, str):
-                      file_paths.append(f_obj)
-                 else:
-                     gr.Error(f"Invalid file input for argument {i+1}. Type: {type(f_obj)}")
-                     return [None] * len(get_all_output_components())
-
-             results = process_files(*file_paths)
-
-             if "error" in results:
-                 # Error already displayed by process_files or here
-                 return [None] * len(get_all_output_components())
-
-             return [
-                 results.get('summary_plot'),
-                 # CF Calib
-                 results.get('cf_total_base_table'), results.get('cf_policy_attrs_total'),
-                 results.get('cf_cashflow_plot'), results.get('cf_scatter_cashflows_base'),
-                 results.get('cf_pv_total_base'), results.get('cf_pv_total_lapse'), results.get('cf_pv_total_mort'),
-                 # Attr Calib
-                 results.get('attr_total_cf_base'), results.get('attr_policy_attrs_total'),
-                 results.get('attr_cashflow_plot'), results.get('attr_scatter_cashflows_base'), results.get('attr_total_pv_base'),
-                 # PV Calib
-                 results.get('pv_total_cf_base'), results.get('pv_policy_attrs_total'),
-                 results.get('pv_cashflow_plot'), results.get('pv_scatter_pvs_base'),
-                 results.get('pv_total_pv_base'), results.get('pv_total_pv_lapse'), results.get('pv_total_pv_mort')
-             ]
-
-         analyze_btn.click(
-             handle_analysis,
-             inputs=[cashflow_base_input, cashflow_lapse_input, cashflow_mort_input,
-                     policy_data_input, pv_base_input, pv_lapse_input, pv_mort_input],
-             outputs=get_all_output_components()
-         )
-
-         # --- Action for Load Example Data Button ---
-         def load_example_files():
-             # Check if all example files exist
-             missing_files = [fp for fp in EXAMPLE_FILES.values() if not os.path.exists(fp)]
-             if missing_files:
-                 gr.Error(f"Missing example data files in '{EXAMPLE_DATA_DIR}': {', '.join(missing_files)}. Please ensure they exist.")
-                 return [None] * 7 # Return Nones for all file inputs
-
-             gr.Info("Example data paths loaded. Click 'Analyze Dataset'.")
-             return [
-                 EXAMPLE_FILES["cashflow_base"], EXAMPLE_FILES["cashflow_lapse"], EXAMPLE_FILES["cashflow_mort"],
-                 EXAMPLE_FILES["policy_data"], EXAMPLE_FILES["pv_base"], EXAMPLE_FILES["pv_lapse"],
-                 EXAMPLE_FILES["pv_mort"]
-             ]
-
-         load_example_btn.click(
-             load_example_files,
-             inputs=[],
-             outputs=[cashflow_base_input, cashflow_lapse_input, cashflow_mort_input,
-                      policy_data_input, pv_base_input, pv_lapse_input, pv_mort_input]
-         )
-
-     return demo

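The handlers above follow the standard Gradio Blocks wiring: build components, then bind a callback with `Button.click(fn, inputs, outputs)`. A minimal self-contained sketch of the same pattern with a toy callback (`echo_path` is hypothetical, not the app's handler):

```python
import gradio as gr

def echo_path(f):
    # Gradio may pass a temp-file object or a plain path string; normalize to a path.
    return f.name if hasattr(f, "name") else str(f)

with gr.Blocks() as demo:
    inp = gr.File(label="Any file")
    btn = gr.Button("Show path")
    out = gr.Textbox(label="Resolved path")
    btn.click(echo_path, inputs=[inp], outputs=[out])

if __name__ == "__main__":
    demo.launch()
```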

  if __name__ == "__main__":
-     # Create the eg_data directory if it doesn't exist (for testing, user should create it with files)
-     if not os.path.exists(EXAMPLE_DATA_DIR):
-         os.makedirs(EXAMPLE_DATA_DIR)
-         print(f"Created directory '{EXAMPLE_DATA_DIR}'. Please place example Excel files there.")
-         # You might want to add dummy files here for basic testing if the real files aren't present
-         # For example:
-         # with open(os.path.join(EXAMPLE_DATA_DIR, "cashflows_seriatim_10K.xlsx"), "w") as f: f.write("")
-         # ... and so on for other files, but they would be empty and cause errors in pd.read_excel.
-         # It's better to instruct the user to add the actual files.
-         print(f"Expected files in '{EXAMPLE_DATA_DIR}': {list(EXAMPLE_FILES.values())}")
-
-     demo_app = create_interface()
-     demo_app.launch()

  import numpy as np
  import pandas as pd
  from sklearn.cluster import KMeans
+ from sklearn.metrics import pairwise_distances_argmin_min  # r2_score is not used in the final Gradio app logic
  import matplotlib.pyplot as plt
  import matplotlib.cm
  import io

  # Define the paths for example data
  EXAMPLE_DATA_DIR = "eg_data"
  EXAMPLE_FILES = {
+     "cashflow_base": os.path.join(EXAMPLE_DATA_DIR, "cashflows_seriatim_10K.xlsx"),
+     "cashflow_lapse": os.path.join(EXAMPLE_DATA_DIR, "cashflows_seriatim_10K_lapse50.xlsx"),
+     "cashflow_mort": os.path.join(EXAMPLE_DATA_DIR, "cashflows_seriatim_10K_mort15.xlsx"),
+     "policy_data": os.path.join(EXAMPLE_DATA_DIR, "model_point_table.xlsx"),
+     "pv_base": os.path.join(EXAMPLE_DATA_DIR, "pv_seriatim_10K.xlsx"),
+     "pv_lapse": os.path.join(EXAMPLE_DATA_DIR, "pv_seriatim_10K_lapse50.xlsx"),
+     "pv_mort": os.path.join(EXAMPLE_DATA_DIR, "pv_seriatim_10K_mort15.xlsx"),
  }

  class Clusters:
+     def __init__(self, loc_vars):
+         # Ensure loc_vars is not empty before fitting KMeans
+         if loc_vars.empty:
+             raise ValueError("Input data for KMeans (loc_vars) is empty.")
+         if loc_vars.isnull().all().all():
+             raise ValueError("Input data for KMeans (loc_vars) contains all NaN values.")
+
+         self.kmeans = KMeans(n_clusters=min(1000, len(loc_vars)), random_state=0, n_init=10).fit(np.ascontiguousarray(loc_vars))
+         closest, _ = pairwise_distances_argmin_min(self.kmeans.cluster_centers_, np.ascontiguousarray(loc_vars))
+
+         rep_ids = pd.Series(data=(closest + 1))  # 0-based to 1-based indexes
+         rep_ids.name = 'policy_id'
+         rep_ids.index.name = 'cluster_id'
+         self.rep_ids = rep_ids
+
+         self.policy_count = self.agg_by_cluster(pd.DataFrame({'policy_count': [1] * len(loc_vars)}))['policy_count']
+
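The constructor picks, for each cluster, the seriatim record closest to the centroid; that selection can be exercised on toy data with just scikit-learn and NumPy:

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import pairwise_distances_argmin_min

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))  # 200 hypothetical policies, 3 calibration variables

km = KMeans(n_clusters=5, random_state=0, n_init=10).fit(X)
closest, dist = pairwise_distances_argmin_min(km.cluster_centers_, X)

# closest[k] is the row index of the policy nearest to centroid k;
# the app shifts these 0-based indexes by +1 to match 1-based policy_id values.
print(closest, dist.round(3))
```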
+     def agg_by_cluster(self, df, agg=None):
+         temp = df.copy()
+         temp['cluster_id'] = self.kmeans.labels_
+         temp = temp.set_index('cluster_id')
+
+         # Ensure agg is a dictionary if not None
+         if agg is not None and not isinstance(agg, dict):
+             # Assuming if agg is not a dict, it's the default "sum" for all, which is handled by else.
+             # This case might need specific handling if agg can be other types.
+             # For now, if it's not a dict, treat as if no specific agg ops were given for columns.
+             agg_ops = {col: "sum" for col in temp.columns}  # Default to sum if agg format is unexpected
+         elif isinstance(agg, dict):
+             agg_ops = {c: (agg[c] if c in agg else 'sum') for c in temp.columns}
+         else:  # agg is None
+             agg_ops = "sum"  # Pandas groupby will apply sum to all numeric columns
+
+         return temp.groupby(temp.index).agg(agg_ops)
+
+     def extract_reps(self, df):
+         temp = pd.merge(self.rep_ids, df.reset_index(), how='left', on='policy_id')
+         temp.index.name = 'cluster_id'
+         return temp.drop('policy_id', axis=1)
+
+     def extract_and_scale_reps(self, df, agg=None):
+         extracted_df = self.extract_reps(df)
+         if extracted_df.empty:
+             return extracted_df  # Return empty if no representatives
+
+         if agg and isinstance(agg, dict):
+             # mult should be a Series aligned with extracted_df's columns for element-wise multiplication after selection
+             # This part of the logic seems to intend to scale rows based on policy_count for 'sum' aggs
+             # and leave 'mean' aggs as is (to be weighted later).
+             # The original code created a DataFrame `mult` then did .mul(mult).
+             # A more direct approach for scaling rows:
+             scaled_df = extracted_df.copy()
+             for c in extracted_df.columns:
+                 if agg.get(c, 'sum') == 'sum':  # Default to 'sum' if column not in agg
+                     scaled_df[c] = extracted_df[c].mul(self.policy_count, axis=0)
+                 # else (it's 'mean'), do not scale by policy_count here.
+             return scaled_df
+         else:  # Default: scale all columns by policy_count (as if for sum)
+             return extracted_df.mul(self.policy_count, axis=0)
+
+     def compare(self, df, agg=None):
+         source = self.agg_by_cluster(df, agg)
+         target = self.extract_and_scale_reps(df, agg)  # This target needs to be aggregated like source
+
+         # The target from extract_and_scale_reps is already scaled per cluster for 'sum' ops.
+         # For 'mean' ops, it's the representative value.
+         # We need to sum up the 'sum' columns and calculate weighted average for 'mean' columns.
+         if agg and isinstance(agg, dict):
+             agg_ops_for_target = {}
+             for col, method in agg.items():
+                 if method == 'sum':
+                     agg_ops_for_target[col] = 'sum'
+                 elif method == 'mean':
+                     # For mean, we need sum(val*count)/sum(count).
+                     # extract_and_scale_reps DID NOT scale mean columns by policy_count.
+                     # So, target[col] has rep values. We need to weight them.
+                     # This is better handled in compare_total. Here, target is per-cluster.
+                     # This function compares per-cluster values BEFORE final aggregation.
+                     # So target should represent aggregated values per cluster.
+                     pass  # 'sum' columns are scaled, 'mean' columns are rep values
+         else:  # all sum
+             pass  # target is already scaled by policy_count, so it's the sum per cluster
+
+         # This function is for per-cluster comparison, not total.
+         # The 'target' from extract_and_scale_reps already has the representative values scaled by policy_count for sum-like aggregations.
+         # If a column is meant for 'mean', it's just the representative value.
+         # This 'compare' function might be misinterpreting 'target' if 'agg' has 'mean'.
+         # The original notebook's compare function:
+         # source = self.agg_by_cluster(df, agg)  # Actual sums/means per cluster
+         # target = self.extract_and_scale_reps(df, agg)  # Rep values, scaled by count if 'sum', unscaled if 'mean'
+         # This structure implies 'target' might not be directly comparable if 'mean' is involved without further processing.
+         # However, the scatter plots it generates plot these per-cluster values.
+         # For 'sum' variables, target is an estimate of the cluster total.
+         # For 'mean' variables, target is the rep's value (estimate of cluster mean).
+
+         return pd.DataFrame({'actual': source.stack(), 'estimate': target.stack()})
+
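The `.stack()` calls are what give `compare()` its two-level `(cluster_id, column)` index, which `plot_scatter_comparison` later checks for. A toy illustration of the shape:

```python
import pandas as pd

actual = pd.DataFrame({'t0': [1.0, 2.0], 't1': [3.0, 4.0]})
actual.index.name = 'cluster_id'

stacked = actual.stack()       # Series with MultiIndex (cluster_id, column)
print(stacked.index.nlevels)   # 2
print(stacked.loc[(0, 't1')])  # 3.0
```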
+     def compare_total(self, df, agg=None):
+         """Aggregate df by columns and compare actual vs estimate totals."""
+         if df.empty:
+             return pd.DataFrame(columns=['actual', 'estimate', 'error'])
+
+         # Determine aggregation operations for each column
+         op_for_actual = {}
+         if isinstance(agg, dict):
+             for c in df.columns:
+                 op_for_actual[c] = agg.get(c, 'sum')  # Default to 'sum' if not in agg
+         else:  # agg is None or not a dict, apply sum to all
+             for c in df.columns:
+                 if pd.api.types.is_numeric_dtype(df[c]):
+                     op_for_actual[c] = 'sum'
+                 # else: non-numeric columns will be ignored by df.agg if op not specified
+
+         actual = df.agg(op_for_actual)
+         actual = actual.dropna()  # Remove non-numeric results if any
+
+         # Calculate estimate
+         reps_values = self.extract_reps(df)  # Get raw representative values (one per cluster)
+         if reps_values.empty:  # No representatives found
+             estimate = pd.Series(index=actual.index, dtype=float)  # Empty or NaN series
+         else:
+             estimate_values = {}
+             for col_name in actual.index:  # Iterate over columns that had a valid actual aggregation
+                 col_op = op_for_actual.get(col_name, 'sum')
+
+                 if col_name not in reps_values.columns:  # Should not happen if df columns match
+                     estimate_values[col_name] = np.nan
+                     continue
+
+                 rep_col_values = reps_values[col_name]
+
+                 if col_op == 'sum':
+                     # Estimate for sum is sum of (representative_value * policy_count_for_its_cluster)
+                     estimate_values[col_name] = (rep_col_values * self.policy_count).sum()
+                 elif col_op == 'mean':
+                     # Estimate for mean is weighted average: sum(rep_value * policy_count) / sum(policy_count)
+                     weighted_sum = (rep_col_values * self.policy_count).sum()
+                     total_weight = self.policy_count.sum()
+                     estimate_values[col_name] = weighted_sum / total_weight if total_weight != 0 else np.nan
+                 else:  # Should not happen given op_for_actual logic
+                     estimate_values[col_name] = np.nan
+
+             estimate = pd.Series(estimate_values, index=actual.index)  # Align with actual's index
+
+         # Calculate error
+         # Align actual and estimate to ensure they cover the same items for error calculation
+         actual_aligned, estimate_aligned = actual.align(estimate, join='inner')
+
+         error = pd.Series(index=actual_aligned.index, dtype=float)
+
+         # Valid division where actual is not zero and not NaN
+         valid_mask = (actual_aligned != 0) & (~actual_aligned.isna())
+         error[valid_mask] = estimate_aligned[valid_mask] / actual_aligned[valid_mask] - 1
+
+         # Where actual is zero (and not NaN)
+         actual_zero_mask = (actual_aligned == 0) & (~actual_aligned.isna())
+         # If estimate is also zero, error is 0
+         error[actual_zero_mask & (estimate_aligned == 0)] = 0
+         # If estimate is non-zero and actual is zero, error is effectively infinite
+         error[actual_zero_mask & (estimate_aligned != 0)] = np.inf
+
+         # Replace any infinities with NaN for cleaner results (e.g., for .mean())
+         error = error.replace([np.inf, -np.inf], np.nan)
+
+         result_df = pd.DataFrame({'actual': actual_aligned, 'estimate': estimate_aligned, 'error': error})
+         return result_df

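The error conventions in the new `compare_total` (ratio error where `actual` is non-zero, zero for a 0/0 match, infinities mapped to NaN) can be reproduced with plain pandas as a sanity check on toy numbers:

```python
import numpy as np
import pandas as pd

actual = pd.Series([100.0, 0.0, 0.0])
estimate = pd.Series([110.0, 0.0, 5.0])

error = estimate / actual - 1                  # 0.1, NaN (0/0), inf (5/0)
error[(actual == 0) & (estimate == 0)] = 0     # treat 0/0 as a perfect match
error = error.replace([np.inf, -np.inf], np.nan)
print(error.tolist())                          # ~[0.1, 0.0, nan]
```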

  def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
+     if not cfs_list or not cluster_obj or not titles or len(cfs_list) == 0:
+         fig, ax = plt.subplots()
+         ax.text(0.5, 0.5, "No data for cashflow comparison plot.", ha='center', va='center')
+         buf = io.BytesIO(); plt.savefig(buf, format='png'); buf.seek(0); img = Image.open(buf); plt.close(fig); return img
+
+     num_plots = len(cfs_list)
+     cols = 2
+     rows = (num_plots + cols - 1) // cols
+
+     fig, axes = plt.subplots(rows, cols, figsize=(15, 5 * rows), squeeze=False)
+     axes = axes.flatten()
+
+     plot_made = False
+     for i, (df_cf, title) in enumerate(zip(cfs_list, titles)):
+         if i < len(axes):
+             if df_cf is None or df_cf.empty:
+                 axes[i].text(0.5, 0.5, f"No data for {title}", ha='center', va='center')
+                 axes[i].set_title(title)
+                 continue
+             comparison = cluster_obj.compare_total(df_cf)  # Default is sum for all columns
+             if not comparison.empty and 'actual' in comparison and 'estimate' in comparison:
+                 comparison[['actual', 'estimate']].plot(ax=axes[i], grid=True, title=title)
+                 axes[i].set_xlabel('Time')
+                 axes[i].set_ylabel('Value')
+                 plot_made = True
+             else:
+                 axes[i].text(0.5, 0.5, f"Could not generate comparison for {title}", ha='center', va='center')
+                 axes[i].set_title(title)
+
+     for j in range(i + 1, len(axes)):  # Hide unused subplots
+         fig.delaxes(axes[j])
+
+     if not plot_made:  # If no plots were actually made (e.g. all data was empty)
+         plt.close(fig)  # Close the figure
+         fig, ax = plt.subplots()  # Create a new one for the message
+         ax.text(0.5, 0.5, "Insufficient data for any cashflow plots.", ha='center', va='center')
+
+     plt.tight_layout()
+     buf = io.BytesIO()
+     plt.savefig(buf, format='png', dpi=100)
+     buf.seek(0)
+     img = Image.open(buf)
+     plt.close(fig)
+     return img

  def plot_scatter_comparison(df_compare_output, title):
+     if df_compare_output is None or df_compare_output.empty:
+         fig, ax = plt.subplots(figsize=(10, 6)); ax.text(0.5, 0.5, "No data for scatter plot.", ha='center', va='center'); ax.set_title(title)
+         buf = io.BytesIO(); plt.savefig(buf, format='png'); buf.seek(0); img = Image.open(buf); plt.close(fig); return img
+
+     fig, ax = plt.subplots(figsize=(10, 6))
+
+     if not isinstance(df_compare_output.index, pd.MultiIndex) or df_compare_output.index.nlevels < 2:
+         # This case indicates df_compare_output is not from cluster_obj.compare() as expected
+         ax.scatter(df_compare_output.get('actual', []), df_compare_output.get('estimate', []), s=9, alpha=0.6)
+     else:
+         unique_levels = df_compare_output.index.get_level_values(1).unique()
+         colors = matplotlib.cm.rainbow(np.linspace(0, 1, len(unique_levels)))
+
+         for item_level, color_val in zip(unique_levels, colors):
+             subset = df_compare_output.xs(item_level, level=1)
+             ax.scatter(subset['actual'], subset['estimate'], color=color_val, s=9, alpha=0.6, label=str(item_level))  # Ensure label is string
+         if len(unique_levels) > 1 and len(unique_levels) <= 10:
+             ax.legend(title=df_compare_output.index.names[1])
+
+     ax.set_xlabel('Actual')
+     ax.set_ylabel('Estimate')
+     ax.set_title(title)
+     ax.grid(True)
+
+     try:
+         current_xlim = ax.get_xlim()
+         current_ylim = ax.get_ylim()
+         lims = [
+             np.nanmin([current_xlim, current_ylim]),
+             np.nanmax([current_xlim, current_ylim]),
+         ]
+         if lims[0] != lims[1] and not np.isnan(lims[0]) and not np.isnan(lims[1]):
+             ax.plot(lims, lims, 'r-', linewidth=0.5)
+             ax.set_xlim(lims)
+             ax.set_ylim(lims)
+     except Exception:  # Catch errors if lims are problematic (e.g. all NaNs)
+         pass
+
+     buf = io.BytesIO()
+     plt.savefig(buf, format='png', dpi=100)
+     buf.seek(0)
+     img = Image.open(buf)
+     plt.close(fig)
+     return img


  def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
+                   policy_data_path, pv_base_path, pv_lapse_path, pv_mort_path):
+     results = {}
+     try:
+         cfs = pd.read_excel(cashflow_base_path, index_col=0)
+         cfs_lapse50 = pd.read_excel(cashflow_lapse_path, index_col=0)
+         cfs_mort15 = pd.read_excel(cashflow_mort_path, index_col=0)
+
+         pol_data_full = pd.read_excel(policy_data_path, index_col=0)
+         required_cols = ['age_at_entry', 'policy_term', 'sum_assured', 'duration_mth']
+         missing_policy_cols = [col for col in required_cols if col not in pol_data_full.columns]
+         if missing_policy_cols:
+             gr.Warning(f"Policy data is missing required columns: {', '.join(missing_policy_cols)}. Analysis may be affected.")
+             pol_data = pol_data_full  # Use what's available
+         else:
+             pol_data = pol_data_full[required_cols]
+
+         pvs = pd.read_excel(pv_base_path, index_col=0)
+         pvs_lapse50 = pd.read_excel(pv_lapse_path, index_col=0)
+         pvs_mort15 = pd.read_excel(pv_mort_path, index_col=0)
+
+         cfs_list = [cfs, cfs_lapse50, cfs_mort15]
+         scen_titles = ['Base', 'Lapse+50%', 'Mort+15%']
+
+         mean_attrs_agg = {'age_at_entry': 'mean', 'policy_term': 'mean', 'duration_mth': 'mean', 'sum_assured': 'sum'}
+
+         # --- 1. Cashflow Calibration ---
+         gr.Info("Starting Cashflow Calibration...")
+         if cfs.empty: gr.Warning("Base cashflow data is empty for Cashflow Calibration.")
+         cluster_cfs = Clusters(cfs)
+         results['cf_total_base_table'] = cluster_cfs.compare_total(cfs)
+         results['cf_policy_attrs_total'] = cluster_cfs.compare_total(pol_data, agg=mean_attrs_agg)
+         results['cf_pv_total_base'] = cluster_cfs.compare_total(pvs)
+         results['cf_pv_total_lapse'] = cluster_cfs.compare_total(pvs_lapse50)
+         results['cf_pv_total_mort'] = cluster_cfs.compare_total(pvs_mort15)
+         results['cf_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_cfs, scen_titles)
+         results['cf_scatter_cashflows_base'] = plot_scatter_comparison(cluster_cfs.compare(cfs), 'CF Calib. - Cashflows (Base)')
+         gr.Info("Cashflow Calibration Done.")
+
+         # --- 2. Policy Attribute Calibration ---
+         gr.Info("Starting Policy Attribute Calibration...")
+         if pol_data.empty:
+             gr.Warning("Policy data is empty. Skipping Policy Attribute Calibration.")
+             loc_vars_attrs = pd.DataFrame()  # Empty dataframe
+         else:
+             pol_data_min = pol_data.min()
+             pol_data_range = pol_data.max() - pol_data_min
+             # Avoid division by zero if a column has no variance (all values are the same)
+             if (pol_data_range == 0).any():
+                 gr.Warning("Some policy attributes have no variance (all values are the same). Standardization might be affected.")
+                 # Constant columns would divide by zero; setting their range to 1 maps them to 0 after (x - min) / range.
+                 pol_data_range[pol_data_range == 0] = 1
+             loc_vars_attrs = (pol_data - pol_data_min) / pol_data_range
+             loc_vars_attrs = loc_vars_attrs.fillna(0)  # Handle any NaNs from perfectly constant columns
+
+         cluster_attrs = None  # Defined only when attribute calibration actually runs
+         if not loc_vars_attrs.empty:
+             cluster_attrs = Clusters(loc_vars_attrs)
+             results['attr_total_cf_base'] = cluster_attrs.compare_total(cfs)
+             results['attr_policy_attrs_total'] = cluster_attrs.compare_total(pol_data, agg=mean_attrs_agg)
+             results['attr_total_pv_base'] = cluster_attrs.compare_total(pvs)
+             results['attr_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_attrs, scen_titles)
+             results['attr_scatter_cashflows_base'] = plot_scatter_comparison(cluster_attrs.compare(cfs), 'Attr Calib. - Cashflows (Base)')
+         else:
+             results.update({k: pd.DataFrame() for k in ['attr_total_cf_base', 'attr_policy_attrs_total', 'attr_total_pv_base']})
+             results.update({k: None for k in ['attr_cashflow_plot', 'attr_scatter_cashflows_base']})
+         gr.Info("Policy Attribute Calibration Done.")
+
+         # --- 3. Present Value Calibration ---
+         gr.Info("Starting Present Value Calibration...")
+         if pvs.empty: gr.Warning("Base Present Value data is empty for PV Calibration.")
+         cluster_pvs = Clusters(pvs)
+         results['pv_total_cf_base'] = cluster_pvs.compare_total(cfs)
+         results['pv_policy_attrs_total'] = cluster_pvs.compare_total(pol_data, agg=mean_attrs_agg)
+         results['pv_total_pv_base'] = cluster_pvs.compare_total(pvs)
+         results['pv_total_pv_lapse'] = cluster_pvs.compare_total(pvs_lapse50)
+         results['pv_total_pv_mort'] = cluster_pvs.compare_total(pvs_mort15)
+         results['pv_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_pvs, scen_titles)
+         results['pv_scatter_pvs_base'] = plot_scatter_comparison(cluster_pvs.compare(pvs), 'PV Calib. - PVs (Base)')
+         gr.Info("Present Value Calibration Done.")
+
+         # --- Summary Comparison Plot Data ---
+         gr.Info("Generating Summary Plot...")
+         error_data = {}
+         pv_col_name = 'PV_NetCF'  # Target column for summary
+
+         for calib_prefix, cluster_obj, calib_name_display in [
+                 ('CF Calib.', cluster_cfs, "CF Calib."),
+                 ('Attr Calib.', cluster_attrs, "Attr Calib."),  # None when attribute calibration was skipped
+                 ('PV Calib.', cluster_pvs, "PV Calib.")]:
+
+             current_calib_errors = []
+             if cluster_obj is None and calib_prefix == 'Attr Calib.':  # Attr calib might be skipped
+                 current_calib_errors = [np.nan, np.nan, np.nan]
+             else:
+                 for pv_df_scenario in [pvs, pvs_lapse50, pvs_mort15]:
+                     if pv_df_scenario.empty:
+                         current_calib_errors.append(np.nan)
+                         continue
+
+                     comp_total_df = cluster_obj.compare_total(pv_df_scenario)
+                     if pv_col_name in comp_total_df.index:
+                         error_val = comp_total_df.loc[pv_col_name, 'error']
+                     elif not comp_total_df.empty and 'error' in comp_total_df.columns:
+                         error_val = comp_total_df['error'].mean()  # Fallback
+                         if calib_prefix == 'CF Calib.' and pv_df_scenario is pvs:  # Only warn once per type if fallback
+                             gr.Warning(f"'{pv_col_name}' not found for summary plot. Using mean error of all PV columns instead for {calib_name_display}.")
+                     else:  # comp_total_df is empty or no 'error' column
+                         error_val = np.nan
+                     current_calib_errors.append(abs(error_val))
+             error_data[calib_name_display] = current_calib_errors
+
+         summary_df = pd.DataFrame(error_data, index=['Base', 'Lapse+50%', 'Mort+15%'])
+
+         fig_summary, ax_summary = plt.subplots(figsize=(10, 6))
+
+         plot_title = f'Calibration Method Comparison - Abs. Error in Total {pv_col_name}'
+         if summary_df.isnull().all().all():
+             ax_summary.text(0.5, 0.5, f"Error data for summary is N/A.\nCheck input PV files for '{pv_col_name}' column and valid numeric data.",
+                             ha='center', va='center', transform=ax_summary.transAxes, wrap=True)
+             ax_summary.set_title(plot_title)
+         elif summary_df.empty:
+             ax_summary.text(0.5, 0.5, "No summary data to plot.", ha='center', va='center')
+             ax_summary.set_title(plot_title)
+         else:
+             summary_df.plot(kind='bar', ax=ax_summary, grid=True)
+             ax_summary.set_ylabel(f'Mean Absolute Error (of {pv_col_name} or fallback)')
+             ax_summary.set_title(plot_title)
+             ax_summary.tick_params(axis='x', rotation=0)
+
+         plt.tight_layout()
+         buf_summary = io.BytesIO(); plt.savefig(buf_summary, format='png', dpi=100); buf_summary.seek(0)
+         results['summary_plot'] = Image.open(buf_summary)
+         plt.close(fig_summary)
+         gr.Info("All processing complete.")
+         return results
+
+     except FileNotFoundError as e:
+         gr.Error(f"File not found: {e.filename}. Ensure example files are in '{EXAMPLE_DATA_DIR}' or all files are uploaded correctly.")
+         return {"error": f"File not found: {e.filename}"}
+     except ValueError as e:  # Catch specific errors like empty data for KMeans
+         gr.Error(f"Data validation error: {str(e)}")
+         return {"error": f"Data error: {str(e)}"}
+     except KeyError as e:
+         gr.Error(f"A required column is missing: {e}. Please check data formats, especially index columns and expected data columns like 'PV_NetCF'.")
+         return {"error": f"Missing column: {e}"}
+     except Exception as e:
+         gr.Error(f"An unexpected error occurred during processing: {str(e)}")
+         import traceback
+         traceback.print_exc()  # Print full traceback to console for debugging
+         return {"error": f"Processing error: {str(e)}"}
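
The attribute calibration above standardizes each policy attribute to [0, 1] and guards against zero-variance columns before dividing. Isolated, with hypothetical toy values, the pattern is:

```python
import pandas as pd

pol_data = pd.DataFrame({'age_at_entry': [25, 40, 55], 'policy_term': [20, 20, 20]})  # toy values

lo = pol_data.min()
col_range = pol_data.max() - lo
col_range[col_range == 0] = 1                  # constant columns would otherwise divide by zero
scaled = ((pol_data - lo) / col_range).fillna(0)
print(scaled)                                  # policy_term collapses to all zeros
```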
  def create_interface():
+     with gr.Blocks(title="Cluster Model Points Analysis") as demo:
+         gr.Markdown("""
+ # Cluster Model Points Analysis
+ This application applies k-means cluster analysis to select representative model points from an insurance portfolio.
+ Upload your Excel files, or use the example data, to compare results across different choices of calibration variables.
+
+ **Required Excel (.xlsx) Files:**
+ - Cashflows - Base Scenario
+ - Cashflows - Lapse Stress (+50%)
+ - Cashflows - Mortality Stress (+15%)
+ - Policy Data (must include 'age_at_entry', 'policy_term', 'sum_assured', 'duration_mth', and an index column for `policy_id`)
+ - Present Values - Base Scenario (ideally with a 'PV_NetCF' column and an index column for `policy_id`)
+ - Present Values - Lapse Stress (same structure as Base PV)
+ - Present Values - Mortality Stress (same structure as Base PV)
+ """)
+
+         with gr.Row():
+             with gr.Column(scale=1):
+                 gr.Markdown("### 📂 Upload Files or Load Examples")
+                 load_example_btn = gr.Button("Load Example Data", icon="💾")
+                 with gr.Row():
+                     cashflow_base_input = gr.File(label="Cashflows - Base", file_types=[".xlsx"])
+                     cashflow_lapse_input = gr.File(label="Cashflows - Lapse Stress", file_types=[".xlsx"])
+                     cashflow_mort_input = gr.File(label="Cashflows - Mortality Stress", file_types=[".xlsx"])
+                 with gr.Row():
+                     policy_data_input = gr.File(label="Policy Data", file_types=[".xlsx"])
+                     pv_base_input = gr.File(label="Present Values - Base", file_types=[".xlsx"])
+                     pv_lapse_input = gr.File(label="Present Values - Lapse Stress", file_types=[".xlsx"])
+                 with gr.Row():
+                     pv_mort_input = gr.File(label="Present Values - Mortality Stress", file_types=[".xlsx"])
+                     span_dummy = gr.File(visible=False)   # Hidden placeholder for layout balance
+                     span_dummy2 = gr.File(visible=False)  # Hidden placeholder for layout balance
+
+         analyze_btn = gr.Button("Analyze Dataset", variant="primary", icon="🚀", scale=1)
+
+         with gr.Tabs():
+             with gr.TabItem("📊 Summary"):
+                 summary_plot_output = gr.Image(label="Calibration Methods Comparison")
+
+             with gr.TabItem("💸 Cashflow Calibration"):
+                 gr.Markdown("### Results: Using Annual Cashflows (Base) as Calibration Variables")
+                 with gr.Row():
+                     cf_total_base_table_out = gr.Dataframe(label="Overall Comparison - Base CF", wrap=True)
+                     cf_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes", wrap=True)
+                 cf_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate)")
+                 cf_scatter_cashflows_base_out = gr.Image(label="Scatter: Per-Cluster Cashflows (Base)")
+                 with gr.Accordion("Present Value Comparisons (Totals)", open=False):
+                     with gr.Row():
+                         cf_pv_total_base_out = gr.Dataframe(label="PVs - Base", wrap=True)
+                         cf_pv_total_lapse_out = gr.Dataframe(label="PVs - Lapse Stress", wrap=True)
+                         cf_pv_total_mort_out = gr.Dataframe(label="PVs - Mortality Stress", wrap=True)
+
+             with gr.TabItem("👤 Policy Attribute Calibration"):
+                 gr.Markdown("### Results: Using Policy Attributes as Calibration Variables")
+                 with gr.Row():
+                     attr_total_cf_base_out = gr.Dataframe(label="Overall Comparison - Base CF", wrap=True)
+                     attr_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes", wrap=True)
+                 attr_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate)")
+                 attr_scatter_cashflows_base_out = gr.Image(label="Scatter: Per-Cluster Cashflows (Base)")
+                 with gr.Accordion("Present Value Comparisons (Totals)", open=False):
+                     attr_total_pv_base_out = gr.Dataframe(label="PVs - Base Scenario", wrap=True)
+
+             with gr.TabItem("💰 Present Value Calibration"):
+                 gr.Markdown("### Results: Using Present Values (Base) as Calibration Variables")
+                 with gr.Row():
+                     pv_total_cf_base_out = gr.Dataframe(label="Overall Comparison - Base CF", wrap=True)
+                     pv_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes", wrap=True)
+                 pv_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate)")
+                 pv_scatter_pvs_base_out = gr.Image(label="Scatter: Per-Cluster PVs (Base)")
+                 with gr.Accordion("Present Value Comparisons (Totals)", open=False):
+                     with gr.Row():
+                         pv_total_pv_base_out = gr.Dataframe(label="PVs - Base", wrap=True)
+                         pv_total_pv_lapse_out = gr.Dataframe(label="PVs - Lapse Stress", wrap=True)
+                         pv_total_pv_mort_out = gr.Dataframe(label="PVs - Mortality Stress", wrap=True)
+
+         # The order here must match the return order of handle_analysis_click below.
+         output_components = [
+             summary_plot_output,
+             cf_total_base_table_out, cf_policy_attrs_total_out, cf_cashflow_plot_out, cf_scatter_cashflows_base_out,
+             cf_pv_total_base_out, cf_pv_total_lapse_out, cf_pv_total_mort_out,
+             attr_total_cf_base_out, attr_policy_attrs_total_out, attr_cashflow_plot_out, attr_scatter_cashflows_base_out, attr_total_pv_base_out,
+             pv_total_cf_base_out, pv_policy_attrs_total_out, pv_cashflow_plot_out, pv_scatter_pvs_base_out,
+             pv_total_pv_base_out, pv_total_pv_lapse_out, pv_total_pv_mort_out
+         ]
+
+         def handle_analysis_click(f1, f2, f3, f4, f5, f6, f7):
+             if not all(f is not None for f in [f1, f2, f3, f4, f5, f6, f7]):
+                 gr.Warning("Not all files have been provided. Please upload all 7 files or load the example data.")
+                 return [None] * len(output_components)  # Clear every output component
+
+             # gr.File inputs yield TemporaryFileWrapper objects for uploads,
+             # or plain string paths when populated by the example loader.
+             file_paths = []
+             for f_obj in [f1, f2, f3, f4, f5, f6, f7]:
+                 if hasattr(f_obj, 'name') and isinstance(f_obj.name, str):  # Uploaded file
+                     file_paths.append(f_obj.name)
+                 elif isinstance(f_obj, str):  # Path from example load
+                     file_paths.append(f_obj)
+                 else:  # Should not happen if all files are present
+                     gr.Error(f"Invalid file input: {f_obj}. Please re-upload or reload the examples.")
+                     return [None] * len(output_components)
+
+             analysis_results = process_files(*file_paths)
+
+             if "error" in analysis_results:  # Error already displayed by process_files
+                 return [None] * len(output_components)
+
+             # Map results to output components (same order as output_components above)
+             return [
+                 analysis_results.get('summary_plot'),
+                 analysis_results.get('cf_total_base_table'), analysis_results.get('cf_policy_attrs_total'),
+                 analysis_results.get('cf_cashflow_plot'), analysis_results.get('cf_scatter_cashflows_base'),
+                 analysis_results.get('cf_pv_total_base'), analysis_results.get('cf_pv_total_lapse'), analysis_results.get('cf_pv_total_mort'),
+                 analysis_results.get('attr_total_cf_base'), analysis_results.get('attr_policy_attrs_total'),
+                 analysis_results.get('attr_cashflow_plot'), analysis_results.get('attr_scatter_cashflows_base'), analysis_results.get('attr_total_pv_base'),
+                 analysis_results.get('pv_total_cf_base'), analysis_results.get('pv_policy_attrs_total'),
+                 analysis_results.get('pv_cashflow_plot'), analysis_results.get('pv_scatter_pvs_base'),
+                 analysis_results.get('pv_total_pv_base'), analysis_results.get('pv_total_pv_lapse'), analysis_results.get('pv_total_pv_mort')
+             ]
+
+         analyze_btn.click(
+             handle_analysis_click,
+             inputs=[cashflow_base_input, cashflow_lapse_input, cashflow_mort_input,
+                     policy_data_input, pv_base_input, pv_lapse_input, pv_mort_input],
+             outputs=output_components
+         )
+
+         input_file_components = [
+             cashflow_base_input, cashflow_lapse_input, cashflow_mort_input,
+             policy_data_input, pv_base_input, pv_lapse_input, pv_mort_input
+         ]
+
+         def load_example_files_action():
+             missing_example_files = [fp for fp in EXAMPLE_FILES.values() if not os.path.exists(fp)]
+             if missing_example_files:
+                 gr.Error(f"Missing example data files in '{EXAMPLE_DATA_DIR}': {', '.join(missing_example_files)}. Please ensure they exist.")
+                 return [None] * len(input_file_components)
+             gr.Info(f"Example data paths loaded from '{EXAMPLE_DATA_DIR}'. Click 'Analyze Dataset'.")
+             return [
+                 EXAMPLE_FILES["cashflow_base"], EXAMPLE_FILES["cashflow_lapse"], EXAMPLE_FILES["cashflow_mort"],
+                 EXAMPLE_FILES["policy_data"], EXAMPLE_FILES["pv_base"], EXAMPLE_FILES["pv_lapse"],
+                 EXAMPLE_FILES["pv_mort"]
+             ]
+
+         load_example_btn.click(load_example_files_action, inputs=[], outputs=input_file_components)
+
+     return demo
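
The policy data layout required above can be illustrated with a small synthetic file. This is only a sketch with made-up values, assuming pandas and openpyxl are installed:

    import pandas as pd

    policy_df = pd.DataFrame(
        {
            "age_at_entry": [32, 47],
            "policy_term": [20, 15],           # years
            "sum_assured": [500_000, 250_000],
            "duration_mth": [27, 113],         # months since entry
        },
        index=pd.Index([1, 2], name="policy_id"),
    )
    policy_df.to_excel("model_point_table.xlsx")  # .xlsx writing needs openpyxl
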
  if __name__ == "__main__":
+     if not os.path.exists(EXAMPLE_DATA_DIR):
+         try:
+             os.makedirs(EXAMPLE_DATA_DIR)
+             print(f"Created directory '{EXAMPLE_DATA_DIR}'. Please place the example Excel files there.")
+             print(f"Expected files: {list(EXAMPLE_FILES.keys())}")
+         except OSError as e:
+             print(f"Error creating directory {EXAMPLE_DATA_DIR}: {e}. Please create it manually.")
+
+     print("Starting Gradio application...")
+     print(f"Note: Ensure your example Excel files are placed in the '{os.getcwd()}{os.sep}{EXAMPLE_DATA_DIR}' folder.")
+     print("Required policy data columns: 'age_at_entry', 'policy_term', 'sum_assured', 'duration_mth' (plus an index column).")
+     print("Recommended PV files column for the summary plot: 'PV_NetCF' (plus an index column).")
+
+     demo_app = create_interface()
+     demo_app.launch()
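
Since a 1000-cluster k-means fit on a 10K-policy portfolio can take a while, one optional variation (a sketch using standard Gradio options, not this app's defaults) is to enable request queueing and bind to all interfaces for containerised deployments:

    demo_app = create_interface()
    demo_app.queue().launch(server_name="0.0.0.0", server_port=7860)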