alidenewade committed on
Commit
e0be832
·
verified ·
1 Parent(s): 15cf6ab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -134
app.py CHANGED
@@ -15,7 +15,7 @@ EXAMPLE_FILES = {
15
  "cashflow_base": os.path.join(EXAMPLE_DATA_DIR, "cashflows_seriatim_10K.xlsx"),
16
  "cashflow_lapse": os.path.join(EXAMPLE_DATA_DIR, "cashflows_seriatim_10K_lapse50.xlsx"),
17
  "cashflow_mort": os.path.join(EXAMPLE_DATA_DIR, "cashflows_seriatim_10K_mort15.xlsx"),
18
- "policy_data": os.path.join(EXAMPLE_DATA_DIR, "model_point_table.xlsx"), # Assuming this is the correct path/name for the example
19
  "pv_base": os.path.join(EXAMPLE_DATA_DIR, "pv_seriatim_10K.xlsx"),
20
  "pv_lapse": os.path.join(EXAMPLE_DATA_DIR, "pv_seriatim_10K_lapse50.xlsx"),
21
  "pv_mort": os.path.join(EXAMPLE_DATA_DIR, "pv_seriatim_10K_mort15.xlsx"),
@@ -68,85 +68,60 @@ class Clusters:
68
  def compare_total(self, df, agg=None):
69
  """Aggregate df by columns"""
70
  if agg:
71
- # cols = df.columns # Not used
72
- op = {c: (agg[c] if c in agg else 'sum') for c in df.columns}
73
- actual = df.agg(op)
 
 
 
 
 
74
 
75
- # For estimate, ensure aggregation ops are correctly applied *after* scaling
76
- scaled_reps = self.extract_and_scale_reps(df, agg=op) # Pass op to ensure correct scaling for mean
 
77
 
78
- # Corrected aggregation for estimate when 'mean' is involved
79
- estimate_agg_ops = {}
80
- for col_name, agg_type in op.items():
81
- if agg_type == 'mean':
82
  # Weighted average for mean columns
83
- estimate_agg_ops[col_name] = lambda s, c=col_name: (s * self.policy_count.reindex(s.index)).sum() / self.policy_count.reindex(s.index).sum() if c in self.policy_count.name else s.mean()
84
- else: # 'sum'
85
- estimate_agg_ops[col_name] = 'sum'
 
 
86
 
87
- # Need to handle the case where extract_and_scale_reps already applied scaling for sum
88
- # The logic in extract_and_scale_reps is:
89
- # mult = pd.DataFrame({c: (self.policy_count if (c not in agg or agg[c] == 'sum') else 1) for c in cols})
90
- # This means 'mean' columns are NOT multiplied by policy_count initially.
91
 
92
- # Let's re-think the estimate aggregation for 'mean'
93
- estimate_scaled = self.extract_and_scale_reps(df, agg=op) # agg=op is important here
94
-
95
- final_estimate_ops = {}
96
- for col, method in op.items():
97
- if method == 'mean':
98
- # For mean, we need the sum of (value * policy_count) / sum(policy_count)
99
- # extract_and_scale_reps with agg=op should have scaled sum-columns by policy_count
100
- # and mean-columns by 1. So, for mean columns in estimate_scaled, we need to multiply by policy_count,
101
- # sum them up, and divide by total policy_count.
102
- # However, the current extract_and_scale_reps scales 'mean' columns by 1.
103
- # So we need to take the mean of these scaled (by 1) values, but it should be a weighted mean.
104
-
105
- # Let's try to be more direct:
106
- # Get the representative policies (unscaled for mean columns)
107
- reps_unscaled_for_mean = self.extract_reps(df)
108
- estimate_values = {}
109
- for c in df.columns:
110
- if op[c] == 'sum':
111
- estimate_values[c] = reps_unscaled_for_mean[c].mul(self.policy_count, axis=0).sum()
112
- elif op[c] == 'mean':
113
- weighted_sum = (reps_unscaled_for_mean[c] * self.policy_count).sum()
114
- total_weight = self.policy_count.sum()
115
- estimate_values[c] = weighted_sum / total_weight if total_weight else 0
116
- estimate = pd.Series(estimate_values)
117
-
118
- else: # original 'sum' logic for all columns
119
- final_estimate_ops[col] = 'sum' # All columns in estimate_scaled are ready to be summed up
120
- estimate = estimate_scaled.agg(final_estimate_ops)
121
-
122
-
123
- else: # Original logic if no agg is specified (all sum)
124
  actual = df.sum()
125
  estimate = self.extract_and_scale_reps(df).sum()
126
 
127
- return pd.DataFrame({'actual': actual, 'estimate': estimate, 'error': estimate / actual - 1})
 
 
 
128
 
129
 
130
  def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
131
  """Create cashflow comparison plots"""
132
  if not cfs_list or not cluster_obj or not titles:
133
- return None # Or a placeholder image
134
  num_plots = len(cfs_list)
135
  if num_plots == 0:
136
  return None
137
 
138
- # Determine subplot layout (e.g., 2x2 or adapt)
139
  cols = 2
140
  rows = (num_plots + cols - 1) // cols
141
 
142
- fig, axes = plt.subplots(rows, cols, figsize=(15, 5 * rows), squeeze=False) # Ensure axes is always 2D
143
  axes = axes.flatten()
144
 
145
  for i, (df, title) in enumerate(zip(cfs_list, titles)):
146
  if i < len(axes):
147
  comparison = cluster_obj.compare_total(df)
148
  comparison[['actual', 'estimate']].plot(ax=axes[i], grid=True, title=title)
149
- axes[i].set_xlabel('Time') # Assuming x-axis is time for cashflows
150
  axes[i].set_ylabel('Value')
151
 
152
  # Hide any unused subplots
@@ -155,10 +130,10 @@ def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
155
 
156
  plt.tight_layout()
157
  buf = io.BytesIO()
158
- plt.savefig(buf, format='png', dpi=100) # Lowered DPI slightly for potentially faster rendering
159
  buf.seek(0)
160
  img = Image.open(buf)
161
- plt.close(fig) # Ensure figure is closed
162
  return img
163
 
164
  def plot_scatter_comparison(df_compare_output, title):
@@ -175,7 +150,7 @@ def plot_scatter_comparison(df_compare_output, title):
175
  plt.close(fig)
176
  return img
177
 
178
- fig, ax = plt.subplots(figsize=(12, 8)) # Use a single Axes object
179
 
180
  if not isinstance(df_compare_output.index, pd.MultiIndex) or df_compare_output.index.nlevels < 2:
181
  gr.Warning("Scatter plot data is not in the expected multi-index format. Plotting raw actual vs estimate.")
@@ -187,10 +162,9 @@ def plot_scatter_comparison(df_compare_output, title):
187
  for item_level, color_val in zip(unique_levels, colors):
188
  subset = df_compare_output.xs(item_level, level=1)
189
  ax.scatter(subset['actual'], subset['estimate'], color=color_val, s=9, alpha=0.6, label=item_level)
190
- if len(unique_levels) > 1 and len(unique_levels) <=10: # Add legend if not too many items
191
  ax.legend(title=df_compare_output.index.names[1])
192
 
193
-
194
  ax.set_xlabel('Actual')
195
  ax.set_ylabel('Estimate')
196
  ax.set_title(title)
@@ -201,7 +175,7 @@ def plot_scatter_comparison(df_compare_output, title):
201
  np.min([ax.get_xlim(), ax.get_ylim()]),
202
  np.max([ax.get_xlim(), ax.get_ylim()]),
203
  ]
204
- if lims[0] != lims[1]: # Avoid issues if all data is zero or a single point
205
  ax.plot(lims, lims, 'r-', linewidth=0.5)
206
  ax.set_xlim(lims)
207
  ax.set_ylim(lims)
@@ -229,30 +203,24 @@ def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
229
  if all(col in pol_data_full.columns for col in required_cols):
230
  pol_data = pol_data_full[required_cols]
231
  else:
232
- # Fallback or error if columns are missing. For now, try to use as is or a subset.
233
  gr.Warning(f"Policy data might be missing required columns. Found: {pol_data_full.columns.tolist()}")
234
  pol_data = pol_data_full
235
 
236
-
237
  pvs = pd.read_excel(pv_base_path, index_col=0)
238
  pvs_lapse50 = pd.read_excel(pv_lapse_path, index_col=0)
239
  pvs_mort15 = pd.read_excel(pv_mort_path, index_col=0)
240
 
241
  cfs_list = [cfs, cfs_lapse50, cfs_mort15]
242
- # pvs_list = [pvs, pvs_lapse50, pvs_mort15] # Not directly used for plotting in this structure
243
  scen_titles = ['Base', 'Lapse+50%', 'Mort+15%']
244
 
245
  results = {}
246
 
247
- mean_attrs = {'age_at_entry':'mean', 'policy_term':'mean', 'duration_mth':'mean', 'sum_assured': 'sum'} # sum_assured is usually summed
248
 
249
  # --- 1. Cashflow Calibration ---
250
  cluster_cfs = Clusters(cfs)
251
 
252
  results['cf_total_base_table'] = cluster_cfs.compare_total(cfs)
253
- # results['cf_total_lapse_table'] = cluster_cfs.compare_total(cfs_lapse50) # For full detail if needed
254
- # results['cf_total_mort_table'] = cluster_cfs.compare_total(cfs_mort15)
255
-
256
  results['cf_policy_attrs_total'] = cluster_cfs.compare_total(pol_data, agg=mean_attrs)
257
 
258
  results['cf_pv_total_base'] = cluster_cfs.compare_total(pvs)
@@ -261,16 +229,14 @@ def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
261
 
262
  results['cf_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_cfs, scen_titles)
263
  results['cf_scatter_cashflows_base'] = plot_scatter_comparison(cluster_cfs.compare(cfs), 'Cashflow Calib. - Cashflows (Base)')
264
- # results['cf_scatter_policy_attrs'] = plot_scatter_comparison(cluster_cfs.compare(pol_data, agg=mean_attrs), 'Cashflow Calib. - Policy Attributes')
265
- # results['cf_scatter_pvs_base'] = plot_scatter_comparison(cluster_cfs.compare(pvs), 'Cashflow Calib. - PVs (Base)')
266
 
267
  # --- 2. Policy Attribute Calibration ---
268
  # Standardize policy attributes
269
- if not pol_data.empty and (pol_data.max() - pol_data.min()).all() != 0 : # Avoid division by zero if a column is constant
270
  loc_vars_attrs = (pol_data - pol_data.min()) / (pol_data.max() - pol_data.min())
271
  else:
272
  gr.Warning("Policy data for attribute calibration is empty or has no variance. Skipping attribute calibration plots.")
273
- loc_vars_attrs = pol_data # or handle as an error/skip
274
 
275
  if not loc_vars_attrs.empty:
276
  cluster_attrs = Clusters(loc_vars_attrs)
@@ -279,16 +245,13 @@ def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
279
  results['attr_total_pv_base'] = cluster_attrs.compare_total(pvs)
280
  results['attr_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_attrs, scen_titles)
281
  results['attr_scatter_cashflows_base'] = plot_scatter_comparison(cluster_attrs.compare(cfs), 'Policy Attr. Calib. - Cashflows (Base)')
282
- # results['attr_scatter_policy_attrs'] = plot_scatter_comparison(cluster_attrs.compare(pol_data, agg=mean_attrs), 'Policy Attr. Calib. - Policy Attributes')
283
-
284
- else: # Fill with None if skipped
285
  results['attr_total_cf_base'] = pd.DataFrame()
286
  results['attr_policy_attrs_total'] = pd.DataFrame()
287
  results['attr_total_pv_base'] = pd.DataFrame()
288
  results['attr_cashflow_plot'] = None
289
  results['attr_scatter_cashflows_base'] = None
290
 
291
-
292
  # --- 3. Present Value Calibration ---
293
  cluster_pvs = Clusters(pvs)
294
 
@@ -301,67 +264,63 @@ def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
301
 
302
  results['pv_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_pvs, scen_titles)
303
  results['pv_scatter_pvs_base'] = plot_scatter_comparison(cluster_pvs.compare(pvs), 'PV Calib. - PVs (Base)')
304
- # results['pv_scatter_cashflows_base'] = plot_scatter_comparison(cluster_pvs.compare(cfs), 'PV Calib. - Cashflows (Base)')
305
-
306
 
307
  # --- Summary Comparison Plot Data ---
308
- # Error metric: Mean Absolute Percentage Error for the 'TOTAL' net present value of cashflows (usually the 'PV_NetCF' column)
309
- # Or sum of absolute errors if percentage is problematic (e.g. actual is zero)
310
- # For simplicity, using mean of the 'error' column from compare_total for key metrics
311
 
312
  error_data = {}
313
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
314
  # Cashflow Calibration Errors
315
- if 'PV_NetCF' in pvs.columns:
316
- err_cf_cal_pv_base = cluster_cfs.compare_total(pvs).loc['PV_NetCF', 'error']
317
- err_cf_cal_pv_lapse = cluster_cfs.compare_total(pvs_lapse50).loc['PV_NetCF', 'error']
318
- err_cf_cal_pv_mort = cluster_cfs.compare_total(pvs_mort15).loc['PV_NetCF', 'error']
319
- error_data['CF Calib. (PV NetCF)'] = [
320
- abs(err_cf_cal_pv_base), abs(err_cf_cal_pv_lapse), abs(err_cf_cal_pv_mort)
321
- ]
322
- else: # Fallback if PV_NetCF is not present
323
- error_data['CF Calib. (PV NetCF)'] = [
324
- abs(cluster_cfs.compare_total(pvs)['error'].mean()),
325
- abs(cluster_cfs.compare_total(pvs_lapse50)['error'].mean()),
326
- abs(cluster_cfs.compare_total(pvs_mort15)['error'].mean())
327
- ]
328
-
329
 
330
  # Policy Attribute Calibration Errors
331
- if not loc_vars_attrs.empty and 'PV_NetCF' in pvs.columns:
332
- err_attr_cal_pv_base = cluster_attrs.compare_total(pvs).loc['PV_NetCF', 'error']
333
- err_attr_cal_pv_lapse = cluster_attrs.compare_total(pvs_lapse50).loc['PV_NetCF', 'error']
334
- err_attr_cal_pv_mort = cluster_attrs.compare_total(pvs_mort15).loc['PV_NetCF', 'error']
335
- error_data['Attr Calib. (PV NetCF)'] = [
336
- abs(err_attr_cal_pv_base), abs(err_attr_cal_pv_lapse), abs(err_attr_cal_pv_mort)
337
  ]
338
  else:
339
- error_data['Attr Calib. (PV NetCF)'] = [np.nan, np.nan, np.nan] # Placeholder if skipped
340
-
341
 
342
  # Present Value Calibration Errors
343
- if 'PV_NetCF' in pvs.columns:
344
- err_pv_cal_pv_base = cluster_pvs.compare_total(pvs).loc['PV_NetCF', 'error']
345
- err_pv_cal_pv_lapse = cluster_pvs.compare_total(pvs_lapse50).loc['PV_NetCF', 'error']
346
- err_pv_cal_pv_mort = cluster_pvs.compare_total(pvs_mort15).loc['PV_NetCF', 'error']
347
- error_data['PV Calib. (PV NetCF)'] = [
348
- abs(err_pv_cal_pv_base), abs(err_pv_cal_pv_lapse), abs(err_pv_cal_pv_mort)
349
- ]
350
- else:
351
- error_data['PV Calib. (PV NetCF)'] = [
352
- abs(cluster_pvs.compare_total(pvs)['error'].mean()),
353
- abs(cluster_pvs.compare_total(pvs_lapse50)['error'].mean()),
354
- abs(cluster_pvs.compare_total(pvs_mort15)['error'].mean())
355
- ]
356
 
357
  # Create Summary Plot
358
  summary_df = pd.DataFrame(error_data, index=['Base', 'Lapse+50%', 'Mort+15%'])
359
 
360
  fig_summary, ax_summary = plt.subplots(figsize=(10, 6))
361
  summary_df.plot(kind='bar', ax=ax_summary, grid=True)
362
- ax_summary.set_ylabel('Mean Absolute Error (of PV_NetCF)')
363
- ax_summary.set_title('Calibration Method Comparison - Error in Total PV Net Cashflow')
 
364
  ax_summary.tick_params(axis='x', rotation=0)
 
365
  plt.tight_layout()
366
 
367
  buf_summary = io.BytesIO()
@@ -384,7 +343,7 @@ def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
384
 
385
 
386
  def create_interface():
387
- with gr.Blocks(title="Cluster Model Points Analysis") as demo: # Removed theme
388
  gr.Markdown("""
389
  # Cluster Model Points Analysis
390
 
@@ -422,7 +381,7 @@ def create_interface():
422
 
423
  with gr.Tabs():
424
  with gr.TabItem("📊 Summary"):
425
- summary_plot_output = gr.Image(label="Calibration Methods Comparison (Error in Total PV Net Cashflow)")
426
 
427
  with gr.TabItem("💸 Cashflow Calibration"):
428
  gr.Markdown("### Results: Using Annual Cashflows as Calibration Variables")
@@ -479,16 +438,12 @@ def create_interface():
479
 
480
  # --- Action for Analyze Button ---
481
  def handle_analysis(f1, f2, f3, f4, f5, f6, f7):
482
- # Ensure all files are provided (either by upload or example load)
483
  files = [f1, f2, f3, f4, f5, f6, f7]
484
- # Gradio File objects have a .name attribute for the temp path
485
- # If they are already strings (from example load), they are paths
486
 
487
  file_paths = []
488
  for i, f_obj in enumerate(files):
489
  if f_obj is None:
490
  gr.Error(f"Missing file input for argument {i+1}. Please upload all files or load examples.")
491
- # Return Nones for all output components
492
  return [None] * len(get_all_output_components())
493
 
494
  # If f_obj is a Gradio FileData object (from direct upload)
@@ -501,11 +456,9 @@ def create_interface():
501
  gr.Error(f"Invalid file input for argument {i+1}. Type: {type(f_obj)}")
502
  return [None] * len(get_all_output_components())
503
 
504
-
505
  results = process_files(*file_paths)
506
 
507
  if "error" in results:
508
- # Error already displayed by process_files or here
509
  return [None] * len(get_all_output_components())
510
 
511
  return [
@@ -532,11 +485,10 @@ def create_interface():
532
 
533
  # --- Action for Load Example Data Button ---
534
  def load_example_files():
535
- # Check if all example files exist
536
  missing_files = [fp for fp in EXAMPLE_FILES.values() if not os.path.exists(fp)]
537
  if missing_files:
538
  gr.Error(f"Missing example data files in '{EXAMPLE_DATA_DIR}': {', '.join(missing_files)}. Please ensure they exist.")
539
- return [None] * 7 # Return Nones for all file inputs
540
 
541
  gr.Info("Example data paths loaded. Click 'Analyze Dataset'.")
542
  return [
@@ -555,17 +507,10 @@ def create_interface():
555
  return demo
556
 
557
  if __name__ == "__main__":
558
- # Create the eg_data directory if it doesn't exist (for testing, user should create it with files)
559
  if not os.path.exists(EXAMPLE_DATA_DIR):
560
  os.makedirs(EXAMPLE_DATA_DIR)
561
  print(f"Created directory '{EXAMPLE_DATA_DIR}'. Please place example Excel files there.")
562
- # You might want to add dummy files here for basic testing if the real files aren't present
563
- # For example:
564
- # with open(os.path.join(EXAMPLE_DATA_DIR, "cashflows_seriatim_10K.xlsx"), "w") as f: f.write("")
565
- # ... and so on for other files, but they would be empty and cause errors in pd.read_excel.
566
- # It's better to instruct the user to add the actual files.
567
  print(f"Expected files in '{EXAMPLE_DATA_DIR}': {list(EXAMPLE_FILES.values())}")
568
 
569
-
570
  demo_app = create_interface()
571
  demo_app.launch()
 
15
  "cashflow_base": os.path.join(EXAMPLE_DATA_DIR, "cashflows_seriatim_10K.xlsx"),
16
  "cashflow_lapse": os.path.join(EXAMPLE_DATA_DIR, "cashflows_seriatim_10K_lapse50.xlsx"),
17
  "cashflow_mort": os.path.join(EXAMPLE_DATA_DIR, "cashflows_seriatim_10K_mort15.xlsx"),
18
+ "policy_data": os.path.join(EXAMPLE_DATA_DIR, "model_point_table.xlsx"),
19
  "pv_base": os.path.join(EXAMPLE_DATA_DIR, "pv_seriatim_10K.xlsx"),
20
  "pv_lapse": os.path.join(EXAMPLE_DATA_DIR, "pv_seriatim_10K_lapse50.xlsx"),
21
  "pv_mort": os.path.join(EXAMPLE_DATA_DIR, "pv_seriatim_10K_mort15.xlsx"),
 
68
  def compare_total(self, df, agg=None):
69
  """Aggregate df by columns"""
70
  if agg:
71
+ # Calculate actual values using specified aggregation
72
+ actual_values = {}
73
+ for col in df.columns:
74
+ if agg.get(col, 'sum') == 'mean':
75
+ actual_values[col] = df[col].mean()
76
+ else: # sum
77
+ actual_values[col] = df[col].sum()
78
+ actual = pd.Series(actual_values)
79
 
80
+ # Calculate estimate values
81
+ reps_unscaled = self.extract_reps(df)
82
+ estimate_values = {}
83
 
84
+ for col in df.columns:
85
+ if agg.get(col, 'sum') == 'mean':
 
 
86
  # Weighted average for mean columns
87
+ weighted_sum = (reps_unscaled[col] * self.policy_count).sum()
88
+ total_weight = self.policy_count.sum()
89
+ estimate_values[col] = weighted_sum / total_weight if total_weight > 0 else 0
90
+ else: # sum
91
+ estimate_values[col] = (reps_unscaled[col] * self.policy_count).sum()
92
 
93
+ estimate = pd.Series(estimate_values)
 
 
 
94
 
95
+ else: # Original logic if no agg is specified (all sum)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  actual = df.sum()
97
  estimate = self.extract_and_scale_reps(df).sum()
98
 
99
+ # Calculate error, handling division by zero
100
+ error = np.where(actual != 0, estimate / actual - 1, 0)
101
+
102
+ return pd.DataFrame({'actual': actual, 'estimate': estimate, 'error': error})
103
 
104
 
105
  def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
106
  """Create cashflow comparison plots"""
107
  if not cfs_list or not cluster_obj or not titles:
108
+ return None
109
  num_plots = len(cfs_list)
110
  if num_plots == 0:
111
  return None
112
 
113
+ # Determine subplot layout
114
  cols = 2
115
  rows = (num_plots + cols - 1) // cols
116
 
117
+ fig, axes = plt.subplots(rows, cols, figsize=(15, 5 * rows), squeeze=False)
118
  axes = axes.flatten()
119
 
120
  for i, (df, title) in enumerate(zip(cfs_list, titles)):
121
  if i < len(axes):
122
  comparison = cluster_obj.compare_total(df)
123
  comparison[['actual', 'estimate']].plot(ax=axes[i], grid=True, title=title)
124
+ axes[i].set_xlabel('Time')
125
  axes[i].set_ylabel('Value')
126
 
127
  # Hide any unused subplots
 
130
 
131
  plt.tight_layout()
132
  buf = io.BytesIO()
133
+ plt.savefig(buf, format='png', dpi=100)
134
  buf.seek(0)
135
  img = Image.open(buf)
136
+ plt.close(fig)
137
  return img
138
 
139
  def plot_scatter_comparison(df_compare_output, title):
 
150
  plt.close(fig)
151
  return img
152
 
153
+ fig, ax = plt.subplots(figsize=(12, 8))
154
 
155
  if not isinstance(df_compare_output.index, pd.MultiIndex) or df_compare_output.index.nlevels < 2:
156
  gr.Warning("Scatter plot data is not in the expected multi-index format. Plotting raw actual vs estimate.")
 
162
  for item_level, color_val in zip(unique_levels, colors):
163
  subset = df_compare_output.xs(item_level, level=1)
164
  ax.scatter(subset['actual'], subset['estimate'], color=color_val, s=9, alpha=0.6, label=item_level)
165
+ if len(unique_levels) > 1 and len(unique_levels) <= 10:
166
  ax.legend(title=df_compare_output.index.names[1])
167
 
 
168
  ax.set_xlabel('Actual')
169
  ax.set_ylabel('Estimate')
170
  ax.set_title(title)
 
175
  np.min([ax.get_xlim(), ax.get_ylim()]),
176
  np.max([ax.get_xlim(), ax.get_ylim()]),
177
  ]
178
+ if lims[0] != lims[1]:
179
  ax.plot(lims, lims, 'r-', linewidth=0.5)
180
  ax.set_xlim(lims)
181
  ax.set_ylim(lims)
 
203
  if all(col in pol_data_full.columns for col in required_cols):
204
  pol_data = pol_data_full[required_cols]
205
  else:
 
206
  gr.Warning(f"Policy data might be missing required columns. Found: {pol_data_full.columns.tolist()}")
207
  pol_data = pol_data_full
208
 
 
209
  pvs = pd.read_excel(pv_base_path, index_col=0)
210
  pvs_lapse50 = pd.read_excel(pv_lapse_path, index_col=0)
211
  pvs_mort15 = pd.read_excel(pv_mort_path, index_col=0)
212
 
213
  cfs_list = [cfs, cfs_lapse50, cfs_mort15]
 
214
  scen_titles = ['Base', 'Lapse+50%', 'Mort+15%']
215
 
216
  results = {}
217
 
218
+ mean_attrs = {'age_at_entry':'mean', 'policy_term':'mean', 'duration_mth':'mean', 'sum_assured': 'sum'}
219
 
220
  # --- 1. Cashflow Calibration ---
221
  cluster_cfs = Clusters(cfs)
222
 
223
  results['cf_total_base_table'] = cluster_cfs.compare_total(cfs)
 
 
 
224
  results['cf_policy_attrs_total'] = cluster_cfs.compare_total(pol_data, agg=mean_attrs)
225
 
226
  results['cf_pv_total_base'] = cluster_cfs.compare_total(pvs)
 
229
 
230
  results['cf_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_cfs, scen_titles)
231
  results['cf_scatter_cashflows_base'] = plot_scatter_comparison(cluster_cfs.compare(cfs), 'Cashflow Calib. - Cashflows (Base)')
 
 
232
 
233
  # --- 2. Policy Attribute Calibration ---
234
  # Standardize policy attributes
235
+ if not pol_data.empty and (pol_data.max() - pol_data.min()).all() != 0:
236
  loc_vars_attrs = (pol_data - pol_data.min()) / (pol_data.max() - pol_data.min())
237
  else:
238
  gr.Warning("Policy data for attribute calibration is empty or has no variance. Skipping attribute calibration plots.")
239
+ loc_vars_attrs = pol_data
240
 
241
  if not loc_vars_attrs.empty:
242
  cluster_attrs = Clusters(loc_vars_attrs)
 
245
  results['attr_total_pv_base'] = cluster_attrs.compare_total(pvs)
246
  results['attr_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_attrs, scen_titles)
247
  results['attr_scatter_cashflows_base'] = plot_scatter_comparison(cluster_attrs.compare(cfs), 'Policy Attr. Calib. - Cashflows (Base)')
248
+ else:
 
 
249
  results['attr_total_cf_base'] = pd.DataFrame()
250
  results['attr_policy_attrs_total'] = pd.DataFrame()
251
  results['attr_total_pv_base'] = pd.DataFrame()
252
  results['attr_cashflow_plot'] = None
253
  results['attr_scatter_cashflows_base'] = None
254
 
 
255
  # --- 3. Present Value Calibration ---
256
  cluster_pvs = Clusters(pvs)
257
 
 
264
 
265
  results['pv_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_pvs, scen_titles)
266
  results['pv_scatter_pvs_base'] = plot_scatter_comparison(cluster_pvs.compare(pvs), 'PV Calib. - PVs (Base)')
 
 
267
 
268
  # --- Summary Comparison Plot Data ---
269
+ # Error metric for key PV column or mean absolute error
 
 
270
 
271
  error_data = {}
272
 
273
+ # Function to safely get error value
274
+ def get_error_safe(compare_result, col_name=None):
275
+ if compare_result.empty:
276
+ return np.nan
277
+ if col_name and col_name in compare_result.index:
278
+ return abs(compare_result.loc[col_name, 'error'])
279
+ else:
280
+ # Use mean absolute error if specific column not found
281
+ return abs(compare_result['error']).mean()
282
+
283
+ # Determine key PV column (try common names)
284
+ key_pv_col = None
285
+ for potential_col in ['PV_NetCF', 'pv_net_cf', 'net_cf_pv', 'PV_Net_CF']:
286
+ if potential_col in pvs.columns:
287
+ key_pv_col = potential_col
288
+ break
289
+
290
  # Cashflow Calibration Errors
291
+ error_data['CF Calib.'] = [
292
+ get_error_safe(cluster_cfs.compare_total(pvs), key_pv_col),
293
+ get_error_safe(cluster_cfs.compare_total(pvs_lapse50), key_pv_col),
294
+ get_error_safe(cluster_cfs.compare_total(pvs_mort15), key_pv_col)
295
+ ]
 
 
 
 
 
 
 
 
 
296
 
297
  # Policy Attribute Calibration Errors
298
+ if not loc_vars_attrs.empty:
299
+ error_data['Attr Calib.'] = [
300
+ get_error_safe(cluster_attrs.compare_total(pvs), key_pv_col),
301
+ get_error_safe(cluster_attrs.compare_total(pvs_lapse50), key_pv_col),
302
+ get_error_safe(cluster_attrs.compare_total(pvs_mort15), key_pv_col)
 
303
  ]
304
  else:
305
+ error_data['Attr Calib.'] = [np.nan, np.nan, np.nan]
 
306
 
307
  # Present Value Calibration Errors
308
+ error_data['PV Calib.'] = [
309
+ get_error_safe(cluster_pvs.compare_total(pvs), key_pv_col),
310
+ get_error_safe(cluster_pvs.compare_total(pvs_lapse50), key_pv_col),
311
+ get_error_safe(cluster_pvs.compare_total(pvs_mort15), key_pv_col)
312
+ ]
 
 
 
 
 
 
 
 
313
 
314
  # Create Summary Plot
315
  summary_df = pd.DataFrame(error_data, index=['Base', 'Lapse+50%', 'Mort+15%'])
316
 
317
  fig_summary, ax_summary = plt.subplots(figsize=(10, 6))
318
  summary_df.plot(kind='bar', ax=ax_summary, grid=True)
319
+ ax_summary.set_ylabel('Absolute Error Rate')
320
+ title_suffix = f' ({key_pv_col})' if key_pv_col else ' (Mean Absolute Error)'
321
+ ax_summary.set_title(f'Calibration Method Comparison - Error in Total PV{title_suffix}')
322
  ax_summary.tick_params(axis='x', rotation=0)
323
+ ax_summary.legend(title='Calibration Method')
324
  plt.tight_layout()
325
 
326
  buf_summary = io.BytesIO()
 
343
 
344
 
345
  def create_interface():
346
+ with gr.Blocks(title="Cluster Model Points Analysis") as demo:
347
  gr.Markdown("""
348
  # Cluster Model Points Analysis
349
 
 
381
 
382
  with gr.Tabs():
383
  with gr.TabItem("📊 Summary"):
384
+ summary_plot_output = gr.Image(label="Calibration Methods Comparison")
385
 
386
  with gr.TabItem("💸 Cashflow Calibration"):
387
  gr.Markdown("### Results: Using Annual Cashflows as Calibration Variables")
 
438
 
439
  # --- Action for Analyze Button ---
440
  def handle_analysis(f1, f2, f3, f4, f5, f6, f7):
 
441
  files = [f1, f2, f3, f4, f5, f6, f7]
 
 
442
 
443
  file_paths = []
444
  for i, f_obj in enumerate(files):
445
  if f_obj is None:
446
  gr.Error(f"Missing file input for argument {i+1}. Please upload all files or load examples.")
 
447
  return [None] * len(get_all_output_components())
448
 
449
  # If f_obj is a Gradio FileData object (from direct upload)
 
456
  gr.Error(f"Invalid file input for argument {i+1}. Type: {type(f_obj)}")
457
  return [None] * len(get_all_output_components())
458
 
 
459
  results = process_files(*file_paths)
460
 
461
  if "error" in results:
 
462
  return [None] * len(get_all_output_components())
463
 
464
  return [
 
485
 
486
  # --- Action for Load Example Data Button ---
487
  def load_example_files():
 
488
  missing_files = [fp for fp in EXAMPLE_FILES.values() if not os.path.exists(fp)]
489
  if missing_files:
490
  gr.Error(f"Missing example data files in '{EXAMPLE_DATA_DIR}': {', '.join(missing_files)}. Please ensure they exist.")
491
+ return [None] * 7
492
 
493
  gr.Info("Example data paths loaded. Click 'Analyze Dataset'.")
494
  return [
 
507
  return demo
508
 
509
  if __name__ == "__main__":
 
510
  if not os.path.exists(EXAMPLE_DATA_DIR):
511
  os.makedirs(EXAMPLE_DATA_DIR)
512
  print(f"Created directory '{EXAMPLE_DATA_DIR}'. Please place example Excel files there.")
 
 
 
 
 
513
  print(f"Expected files in '{EXAMPLE_DATA_DIR}': {list(EXAMPLE_FILES.values())}")
514
 
 
515
  demo_app = create_interface()
516
  demo_app.launch()