alidenewade committed on
Commit
9846b45
·
verified ·
1 Parent(s): e5a1f5c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +125 -89
app.py CHANGED
@@ -2,11 +2,11 @@ import gradio as gr
2
  import numpy as np
3
  import pandas as pd
4
  from sklearn.cluster import KMeans
5
- from sklearn.metrics import pairwise_distances_argmin_min
6
- import seaborn as sns
7
  import matplotlib.pyplot as plt
 
8
  import io
9
- import os
10
  from PIL import Image
11
 
12
  # Define the paths for example data
@@ -23,10 +23,10 @@ EXAMPLE_FILES = {
23
 
24
  class Clusters:
25
  def __init__(self, loc_vars):
26
- self.kmeans = KMeans(n_clusters=1000, random_state=0, n_init=10).fit(np.ascontiguousarray(loc_vars))
27
- closest, _ = pairwise_distances_argmin_min(self.kmeans.cluster_centers_, np.ascontiguousarray(loc_vars))
28
 
29
- rep_ids = pd.Series(data=(closest+1))
30
  rep_ids.name = 'policy_id'
31
  rep_ids.index.name = 'cluster_id'
32
  self.rep_ids = rep_ids
@@ -34,6 +34,7 @@ class Clusters:
34
  self.policy_count = self.agg_by_cluster(pd.DataFrame({'policy_count': [1] * len(loc_vars)}))['policy_count']
35
 
36
  def agg_by_cluster(self, df, agg=None):
 
37
  temp = df.copy()
38
  temp['cluster_id'] = self.kmeans.labels_
39
  temp = temp.set_index('cluster_id')
@@ -41,14 +42,17 @@ class Clusters:
41
  return temp.groupby(temp.index).agg(agg)
42
 
43
  def extract_reps(self, df):
 
44
  temp = pd.merge(self.rep_ids, df.reset_index(), how='left', on='policy_id')
45
  temp.index.name = 'cluster_id'
46
  return temp.drop('policy_id', axis=1)
47
 
48
  def extract_and_scale_reps(self, df, agg=None):
 
49
  if agg:
50
  cols = df.columns
51
  mult = pd.DataFrame({c: (self.policy_count if (c not in agg or agg[c] == 'sum') else 1) for c in cols})
 
52
  extracted_df = self.extract_reps(df)
53
  mult.index = extracted_df.index
54
  return extracted_df.mul(mult)
@@ -56,47 +60,57 @@ class Clusters:
56
  return self.extract_reps(df).mul(self.policy_count, axis=0)
57
 
58
  def compare(self, df, agg=None):
 
59
  source = self.agg_by_cluster(df, agg)
60
  target = self.extract_and_scale_reps(df, agg)
61
  return pd.DataFrame({'actual': source.stack(), 'estimate':target.stack()})
62
 
63
  def compare_total(self, df, agg=None):
 
64
  if agg:
 
65
  actual_values = {}
66
  for col in df.columns:
67
  if agg.get(col, 'sum') == 'mean':
68
  actual_values[col] = df[col].mean()
69
- else:
70
  actual_values[col] = df[col].sum()
71
  actual = pd.Series(actual_values)
72
 
 
73
  reps_unscaled = self.extract_reps(df)
74
  estimate_values = {}
75
 
76
  for col in df.columns:
77
  if agg.get(col, 'sum') == 'mean':
 
78
  weighted_sum = (reps_unscaled[col] * self.policy_count).sum()
79
  total_weight = self.policy_count.sum()
80
  estimate_values[col] = weighted_sum / total_weight if total_weight > 0 else 0
81
- else:
82
  estimate_values[col] = (reps_unscaled[col] * self.policy_count).sum()
83
 
84
  estimate = pd.Series(estimate_values)
85
- else:
 
86
  actual = df.sum()
87
  estimate = self.extract_and_scale_reps(df).sum()
88
 
 
89
  error = np.where(actual != 0, estimate / actual - 1, 0)
 
90
  return pd.DataFrame({'actual': actual, 'estimate': estimate, 'error': error})
91
 
 
92
  def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
 
93
  if not cfs_list or not cluster_obj or not titles:
94
  return None
95
-
96
  num_plots = len(cfs_list)
97
  if num_plots == 0:
98
  return None
99
 
 
100
  cols = 2
101
  rows = (num_plots + cols - 1) // cols
102
 
@@ -106,17 +120,11 @@ def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
106
  for i, (df, title) in enumerate(zip(cfs_list, titles)):
107
  if i < len(axes):
108
  comparison = cluster_obj.compare_total(df)
109
-
110
- # Plot using seaborn lineplot for cleaner aesthetics
111
- data_to_plot = comparison[['actual', 'estimate']].reset_index()
112
- data_melted = data_to_plot.melt(id_vars='index', var_name='Type', value_name='Value')
113
-
114
- sns.lineplot(data=data_melted, x='index', y='Value', hue='Type', ax=axes[i])
115
- axes[i].set_title(title)
116
  axes[i].set_xlabel('Time')
117
  axes[i].set_ylabel('Value')
118
- axes[i].grid(True)
119
 
 
120
  for j in range(i + 1, len(axes)):
121
  fig.delaxes(axes[j])
122
 
@@ -129,7 +137,9 @@ def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
129
  return img
130
 
131
  def plot_scatter_comparison(df_compare_output, title):
 
132
  if df_compare_output is None or df_compare_output.empty:
 
133
  fig, ax = plt.subplots(figsize=(12, 8))
134
  ax.text(0.5, 0.5, "No data to display", ha='center', va='center', fontsize=15)
135
  ax.set_title(title)
@@ -143,20 +153,17 @@ def plot_scatter_comparison(df_compare_output, title):
143
  fig, ax = plt.subplots(figsize=(12, 8))
144
 
145
  if not isinstance(df_compare_output.index, pd.MultiIndex) or df_compare_output.index.nlevels < 2:
146
- gr.Warning("Scatter plot data is not in the expected multi-index format. Plotting raw actual vs estimate.")
147
- sns.scatterplot(x='actual', y='estimate', data=df_compare_output, s=9, alpha=0.6, ax=ax)
148
  else:
149
- # Prepare data for seaborn
150
- plot_data = df_compare_output.reset_index()
151
- level_1_name = df_compare_output.index.names[1]
152
-
153
  unique_levels = df_compare_output.index.get_level_values(1).unique()
 
154
 
 
 
 
155
  if len(unique_levels) > 1 and len(unique_levels) <= 10:
156
- sns.scatterplot(x='actual', y='estimate', hue=level_1_name,
157
- data=plot_data, s=9, alpha=0.6, ax=ax)
158
- else:
159
- sns.scatterplot(x='actual', y='estimate', data=plot_data, s=9, alpha=0.6, ax=ax)
160
 
161
  ax.set_xlabel('Actual')
162
  ax.set_ylabel('Estimate')
@@ -169,9 +176,9 @@ def plot_scatter_comparison(df_compare_output, title):
169
  np.max([ax.get_xlim(), ax.get_ylim()]),
170
  ]
171
  if lims[0] != lims[1]:
172
- ax.plot(lims, lims, 'r-', linewidth=0.5)
173
- ax.set_xlim(lims)
174
- ax.set_ylim(lims)
175
 
176
  buf = io.BytesIO()
177
  plt.savefig(buf, format='png', dpi=100)
@@ -180,15 +187,18 @@ def plot_scatter_comparison(df_compare_output, title):
180
  plt.close(fig)
181
  return img
182
 
 
183
  def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
184
  policy_data_path, pv_base_path, pv_lapse_path, pv_mort_path):
 
185
  try:
186
- # Read files
187
  cfs = pd.read_excel(cashflow_base_path, index_col=0)
188
  cfs_lapse50 = pd.read_excel(cashflow_lapse_path, index_col=0)
189
  cfs_mort15 = pd.read_excel(cashflow_mort_path, index_col=0)
190
 
191
  pol_data_full = pd.read_excel(policy_data_path, index_col=0)
 
192
  required_cols = ['age_at_entry', 'policy_term', 'sum_assured', 'duration_mth']
193
  if all(col in pol_data_full.columns for col in required_cols):
194
  pol_data = pol_data_full[required_cols]
@@ -204,90 +214,104 @@ def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
204
  scen_titles = ['Base', 'Lapse+50%', 'Mort+15%']
205
 
206
  results = {}
 
207
  mean_attrs = {'age_at_entry':'mean', 'policy_term':'mean', 'duration_mth':'mean', 'sum_assured': 'sum'}
208
 
209
- # Cashflow Calibration
210
  cluster_cfs = Clusters(cfs)
211
- results.update({
212
- 'cf_total_base_table': cluster_cfs.compare_total(cfs),
213
- 'cf_policy_attrs_total': cluster_cfs.compare_total(pol_data, agg=mean_attrs),
214
- 'cf_pv_total_base': cluster_cfs.compare_total(pvs),
215
- 'cf_pv_total_lapse': cluster_cfs.compare_total(pvs_lapse50),
216
- 'cf_pv_total_mort': cluster_cfs.compare_total(pvs_mort15),
217
- 'cf_cashflow_plot': plot_cashflows_comparison(cfs_list, cluster_cfs, scen_titles),
218
- 'cf_scatter_cashflows_base': plot_scatter_comparison(cluster_cfs.compare(cfs), 'Cashflow Calib. - Cashflows (Base)')
219
- })
220
-
221
- # Policy Attribute Calibration
 
 
222
  if not pol_data.empty and (pol_data.max() - pol_data.min()).all() != 0:
223
- loc_vars_attrs = (pol_data - pol_data.min()) / (pol_data.max() - pol_data.min())
224
  else:
225
  gr.Warning("Policy data for attribute calibration is empty or has no variance. Skipping attribute calibration plots.")
226
  loc_vars_attrs = pol_data
227
 
228
  if not loc_vars_attrs.empty:
229
  cluster_attrs = Clusters(loc_vars_attrs)
230
- results.update({
231
- 'attr_total_cf_base': cluster_attrs.compare_total(cfs),
232
- 'attr_policy_attrs_total': cluster_attrs.compare_total(pol_data, agg=mean_attrs),
233
- 'attr_total_pv_base': cluster_attrs.compare_total(pvs),
234
- 'attr_cashflow_plot': plot_cashflows_comparison(cfs_list, cluster_attrs, scen_titles),
235
- 'attr_scatter_cashflows_base': plot_scatter_comparison(cluster_attrs.compare(cfs), 'Policy Attr. Calib. - Cashflows (Base)')
236
- })
237
  else:
238
- results.update({
239
- 'attr_total_cf_base': pd.DataFrame(),
240
- 'attr_policy_attrs_total': pd.DataFrame(),
241
- 'attr_total_pv_base': pd.DataFrame(),
242
- 'attr_cashflow_plot': None,
243
- 'attr_scatter_cashflows_base': None
244
- })
245
-
246
- # Present Value Calibration
247
  cluster_pvs = Clusters(pvs)
248
- results.update({
249
- 'pv_total_cf_base': cluster_pvs.compare_total(cfs),
250
- 'pv_policy_attrs_total': cluster_pvs.compare_total(pol_data, agg=mean_attrs),
251
- 'pv_total_pv_base': cluster_pvs.compare_total(pvs),
252
- 'pv_total_pv_lapse': cluster_pvs.compare_total(pvs_lapse50),
253
- 'pv_total_pv_mort': cluster_pvs.compare_total(pvs_mort15),
254
- 'pv_cashflow_plot': plot_cashflows_comparison(cfs_list, cluster_pvs, scen_titles),
255
- 'pv_scatter_pvs_base': plot_scatter_comparison(cluster_pvs.compare(pvs), 'PV Calib. - PVs (Base)')
256
- })
257
-
258
- # Summary Comparison Plot
 
 
 
 
 
 
259
  def get_error_safe(compare_result, col_name=None):
260
  if compare_result.empty:
261
  return np.nan
262
  if col_name and col_name in compare_result.index:
263
  return abs(compare_result.loc[col_name, 'error'])
264
  else:
 
265
  return abs(compare_result['error']).mean()
266
 
 
267
  key_pv_col = None
268
  for potential_col in ['PV_NetCF', 'pv_net_cf', 'net_cf_pv', 'PV_Net_CF']:
269
  if potential_col in pvs.columns:
270
  key_pv_col = potential_col
271
  break
272
 
273
- error_data = {
274
- 'CF Calib.': [
275
- get_error_safe(cluster_cfs.compare_total(pvs), key_pv_col),
276
- get_error_safe(cluster_cfs.compare_total(pvs_lapse50), key_pv_col),
277
- get_error_safe(cluster_cfs.compare_total(pvs_mort15), key_pv_col)
278
- ],
279
- 'Attr Calib.': [
280
- get_error_safe(cluster_attrs.compare_total(pvs), key_pv_col) if not loc_vars_attrs.empty else np.nan,
281
- get_error_safe(cluster_attrs.compare_total(pvs_lapse50), key_pv_col) if not loc_vars_attrs.empty else np.nan,
282
- get_error_safe(cluster_attrs.compare_total(pvs_mort15), key_pv_col) if not loc_vars_attrs.empty else np.nan
283
- ] if not loc_vars_attrs.empty else [np.nan, np.nan, np.nan],
284
- 'PV Calib.': [
285
- get_error_safe(cluster_pvs.compare_total(pvs), key_pv_col),
286
- get_error_safe(cluster_pvs.compare_total(pvs_lapse50), key_pv_col),
287
- get_error_safe(cluster_pvs.compare_total(pvs_mort15), key_pv_col)
288
  ]
289
- }
 
 
 
 
 
 
 
 
290
 
 
291
  summary_df = pd.DataFrame(error_data, index=['Base', 'Lapse+50%', 'Mort+15%'])
292
 
293
  fig_summary, ax_summary = plt.subplots(figsize=(10, 6))
@@ -317,6 +341,7 @@ def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
317
  gr.Error(f"Error processing files: {str(e)}")
318
  return {"error": f"Error processing files: {str(e)}"}
319
 
 
320
  def create_interface():
321
  with gr.Blocks(title="Cluster Model Points Analysis") as demo:
322
  gr.Markdown("""
@@ -394,30 +419,37 @@ def create_interface():
394
  pv_total_pv_lapse_out = gr.Dataframe(label="PVs - Lapse Stress Total")
395
  pv_total_pv_mort_out = gr.Dataframe(label="PVs - Mortality Stress Total")
396
 
 
397
  def get_all_output_components():
398
  return [
399
  summary_plot_output,
 
400
  cf_total_base_table_out, cf_policy_attrs_total_out,
401
  cf_cashflow_plot_out, cf_scatter_cashflows_base_out,
402
  cf_pv_total_base_out, cf_pv_total_lapse_out, cf_pv_total_mort_out,
 
403
  attr_total_cf_base_out, attr_policy_attrs_total_out,
404
  attr_cashflow_plot_out, attr_scatter_cashflows_base_out, attr_total_pv_base_out,
 
405
  pv_total_cf_base_out, pv_policy_attrs_total_out,
406
  pv_cashflow_plot_out, pv_scatter_pvs_base_out,
407
  pv_total_pv_base_out, pv_total_pv_lapse_out, pv_total_pv_mort_out
408
  ]
409
 
 
410
  def handle_analysis(f1, f2, f3, f4, f5, f6, f7):
411
  files = [f1, f2, f3, f4, f5, f6, f7]
412
- file_paths = []
413
 
 
414
  for i, f_obj in enumerate(files):
415
  if f_obj is None:
416
  gr.Error(f"Missing file input for argument {i+1}. Please upload all files or load examples.")
417
  return [None] * len(get_all_output_components())
418
 
 
419
  if hasattr(f_obj, 'name') and isinstance(f_obj.name, str):
420
  file_paths.append(f_obj.name)
 
421
  elif isinstance(f_obj, str):
422
  file_paths.append(f_obj)
423
  else:
@@ -431,11 +463,14 @@ def create_interface():
431
 
432
  return [
433
  results.get('summary_plot'),
 
434
  results.get('cf_total_base_table'), results.get('cf_policy_attrs_total'),
435
  results.get('cf_cashflow_plot'), results.get('cf_scatter_cashflows_base'),
436
  results.get('cf_pv_total_base'), results.get('cf_pv_total_lapse'), results.get('cf_pv_total_mort'),
 
437
  results.get('attr_total_cf_base'), results.get('attr_policy_attrs_total'),
438
  results.get('attr_cashflow_plot'), results.get('attr_scatter_cashflows_base'), results.get('attr_total_pv_base'),
 
439
  results.get('pv_total_cf_base'), results.get('pv_policy_attrs_total'),
440
  results.get('pv_cashflow_plot'), results.get('pv_scatter_pvs_base'),
441
  results.get('pv_total_pv_base'), results.get('pv_total_pv_lapse'), results.get('pv_total_pv_mort')
@@ -448,6 +483,7 @@ def create_interface():
448
  outputs=get_all_output_components()
449
  )
450
 
 
451
  def load_example_files():
452
  missing_files = [fp for fp in EXAMPLE_FILES.values() if not os.path.exists(fp)]
453
  if missing_files:
 
2
  import numpy as np
3
  import pandas as pd
4
  from sklearn.cluster import KMeans
5
+ from sklearn.metrics import pairwise_distances_argmin_min, r2_score
 
6
  import matplotlib.pyplot as plt
7
+ import matplotlib.cm
8
  import io
9
+ import os # Added for path joining
10
  from PIL import Image
11
 
12
  # Define the paths for example data
 
23
 
24
  class Clusters:
25
  def __init__(self, loc_vars):
26
+ self.kmeans = kmeans = KMeans(n_clusters=1000, random_state=0, n_init=10).fit(np.ascontiguousarray(loc_vars))
27
+ closest, _ = pairwise_distances_argmin_min(kmeans.cluster_centers_, np.ascontiguousarray(loc_vars))
28
 
29
+ rep_ids = pd.Series(data=(closest+1)) # 0-based to 1-based indexes
30
  rep_ids.name = 'policy_id'
31
  rep_ids.index.name = 'cluster_id'
32
  self.rep_ids = rep_ids
 
34
  self.policy_count = self.agg_by_cluster(pd.DataFrame({'policy_count': [1] * len(loc_vars)}))['policy_count']
35
 
36
  def agg_by_cluster(self, df, agg=None):
37
+ """Aggregate columns by cluster"""
38
  temp = df.copy()
39
  temp['cluster_id'] = self.kmeans.labels_
40
  temp = temp.set_index('cluster_id')
 
42
  return temp.groupby(temp.index).agg(agg)
43
 
44
  def extract_reps(self, df):
45
+ """Extract the rows of representative policies"""
46
  temp = pd.merge(self.rep_ids, df.reset_index(), how='left', on='policy_id')
47
  temp.index.name = 'cluster_id'
48
  return temp.drop('policy_id', axis=1)
49
 
50
  def extract_and_scale_reps(self, df, agg=None):
51
+ """Extract and scale the rows of representative policies"""
52
  if agg:
53
  cols = df.columns
54
  mult = pd.DataFrame({c: (self.policy_count if (c not in agg or agg[c] == 'sum') else 1) for c in cols})
55
+ # Ensure mult has same index as extract_reps(df) for proper alignment
56
  extracted_df = self.extract_reps(df)
57
  mult.index = extracted_df.index
58
  return extracted_df.mul(mult)
 
60
  return self.extract_reps(df).mul(self.policy_count, axis=0)
61
 
62
  def compare(self, df, agg=None):
63
+ """Returns a multi-indexed Dataframe comparing actual and estimate"""
64
  source = self.agg_by_cluster(df, agg)
65
  target = self.extract_and_scale_reps(df, agg)
66
  return pd.DataFrame({'actual': source.stack(), 'estimate':target.stack()})
67
 
68
  def compare_total(self, df, agg=None):
69
+ """Aggregate df by columns"""
70
  if agg:
71
+ # Calculate actual values using specified aggregation
72
  actual_values = {}
73
  for col in df.columns:
74
  if agg.get(col, 'sum') == 'mean':
75
  actual_values[col] = df[col].mean()
76
+ else: # sum
77
  actual_values[col] = df[col].sum()
78
  actual = pd.Series(actual_values)
79
 
80
+ # Calculate estimate values
81
  reps_unscaled = self.extract_reps(df)
82
  estimate_values = {}
83
 
84
  for col in df.columns:
85
  if agg.get(col, 'sum') == 'mean':
86
+ # Weighted average for mean columns
87
  weighted_sum = (reps_unscaled[col] * self.policy_count).sum()
88
  total_weight = self.policy_count.sum()
89
  estimate_values[col] = weighted_sum / total_weight if total_weight > 0 else 0
90
+ else: # sum
91
  estimate_values[col] = (reps_unscaled[col] * self.policy_count).sum()
92
 
93
  estimate = pd.Series(estimate_values)
94
+
95
+ else: # Original logic if no agg is specified (all sum)
96
  actual = df.sum()
97
  estimate = self.extract_and_scale_reps(df).sum()
98
 
99
+ # Calculate error, handling division by zero
100
  error = np.where(actual != 0, estimate / actual - 1, 0)
101
+
102
  return pd.DataFrame({'actual': actual, 'estimate': estimate, 'error': error})
103
 
104
+
105
  def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
106
+ """Create cashflow comparison plots"""
107
  if not cfs_list or not cluster_obj or not titles:
108
  return None
 
109
  num_plots = len(cfs_list)
110
  if num_plots == 0:
111
  return None
112
 
113
+ # Determine subplot layout
114
  cols = 2
115
  rows = (num_plots + cols - 1) // cols
116
 
 
120
  for i, (df, title) in enumerate(zip(cfs_list, titles)):
121
  if i < len(axes):
122
  comparison = cluster_obj.compare_total(df)
123
+ comparison[['actual', 'estimate']].plot(ax=axes[i], grid=True, title=title)
 
 
 
 
 
 
124
  axes[i].set_xlabel('Time')
125
  axes[i].set_ylabel('Value')
 
126
 
127
+ # Hide any unused subplots
128
  for j in range(i + 1, len(axes)):
129
  fig.delaxes(axes[j])
130
 
 
137
  return img
138
 
139
  def plot_scatter_comparison(df_compare_output, title):
140
+ """Create scatter plot comparison from compare() output"""
141
  if df_compare_output is None or df_compare_output.empty:
142
+ # Create a blank plot with a message
143
  fig, ax = plt.subplots(figsize=(12, 8))
144
  ax.text(0.5, 0.5, "No data to display", ha='center', va='center', fontsize=15)
145
  ax.set_title(title)
 
153
  fig, ax = plt.subplots(figsize=(12, 8))
154
 
155
  if not isinstance(df_compare_output.index, pd.MultiIndex) or df_compare_output.index.nlevels < 2:
156
+ gr.Warning("Scatter plot data is not in the expected multi-index format. Plotting raw actual vs estimate.")
157
+ ax.scatter(df_compare_output['actual'], df_compare_output['estimate'], s=9, alpha=0.6)
158
  else:
 
 
 
 
159
  unique_levels = df_compare_output.index.get_level_values(1).unique()
160
+ colors = matplotlib.cm.rainbow(np.linspace(0, 1, len(unique_levels)))
161
 
162
+ for item_level, color_val in zip(unique_levels, colors):
163
+ subset = df_compare_output.xs(item_level, level=1)
164
+ ax.scatter(subset['actual'], subset['estimate'], color=color_val, s=9, alpha=0.6, label=item_level)
165
  if len(unique_levels) > 1 and len(unique_levels) <= 10:
166
+ ax.legend(title=df_compare_output.index.names[1])
 
 
 
167
 
168
  ax.set_xlabel('Actual')
169
  ax.set_ylabel('Estimate')
 
176
  np.max([ax.get_xlim(), ax.get_ylim()]),
177
  ]
178
  if lims[0] != lims[1]:
179
+ ax.plot(lims, lims, 'r-', linewidth=0.5)
180
+ ax.set_xlim(lims)
181
+ ax.set_ylim(lims)
182
 
183
  buf = io.BytesIO()
184
  plt.savefig(buf, format='png', dpi=100)
 
187
  plt.close(fig)
188
  return img
189
 
190
+
191
  def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
192
  policy_data_path, pv_base_path, pv_lapse_path, pv_mort_path):
193
+ """Main processing function - now accepts file paths"""
194
  try:
195
+ # Read uploaded files using paths
196
  cfs = pd.read_excel(cashflow_base_path, index_col=0)
197
  cfs_lapse50 = pd.read_excel(cashflow_lapse_path, index_col=0)
198
  cfs_mort15 = pd.read_excel(cashflow_mort_path, index_col=0)
199
 
200
  pol_data_full = pd.read_excel(policy_data_path, index_col=0)
201
+ # Ensure the correct columns are selected for pol_data
202
  required_cols = ['age_at_entry', 'policy_term', 'sum_assured', 'duration_mth']
203
  if all(col in pol_data_full.columns for col in required_cols):
204
  pol_data = pol_data_full[required_cols]
 
214
  scen_titles = ['Base', 'Lapse+50%', 'Mort+15%']
215
 
216
  results = {}
217
+
218
  mean_attrs = {'age_at_entry':'mean', 'policy_term':'mean', 'duration_mth':'mean', 'sum_assured': 'sum'}
219
 
220
+ # --- 1. Cashflow Calibration ---
221
  cluster_cfs = Clusters(cfs)
222
+
223
+ results['cf_total_base_table'] = cluster_cfs.compare_total(cfs)
224
+ results['cf_policy_attrs_total'] = cluster_cfs.compare_total(pol_data, agg=mean_attrs)
225
+
226
+ results['cf_pv_total_base'] = cluster_cfs.compare_total(pvs)
227
+ results['cf_pv_total_lapse'] = cluster_cfs.compare_total(pvs_lapse50)
228
+ results['cf_pv_total_mort'] = cluster_cfs.compare_total(pvs_mort15)
229
+
230
+ results['cf_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_cfs, scen_titles)
231
+ results['cf_scatter_cashflows_base'] = plot_scatter_comparison(cluster_cfs.compare(cfs), 'Cashflow Calib. - Cashflows (Base)')
232
+
233
+ # --- 2. Policy Attribute Calibration ---
234
+ # Standardize policy attributes
235
  if not pol_data.empty and (pol_data.max() - pol_data.min()).all() != 0:
236
+ loc_vars_attrs = (pol_data - pol_data.min()) / (pol_data.max() - pol_data.min())
237
  else:
238
  gr.Warning("Policy data for attribute calibration is empty or has no variance. Skipping attribute calibration plots.")
239
  loc_vars_attrs = pol_data
240
 
241
  if not loc_vars_attrs.empty:
242
  cluster_attrs = Clusters(loc_vars_attrs)
243
+ results['attr_total_cf_base'] = cluster_attrs.compare_total(cfs)
244
+ results['attr_policy_attrs_total'] = cluster_attrs.compare_total(pol_data, agg=mean_attrs)
245
+ results['attr_total_pv_base'] = cluster_attrs.compare_total(pvs)
246
+ results['attr_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_attrs, scen_titles)
247
+ results['attr_scatter_cashflows_base'] = plot_scatter_comparison(cluster_attrs.compare(cfs), 'Policy Attr. Calib. - Cashflows (Base)')
 
 
248
  else:
249
+ results['attr_total_cf_base'] = pd.DataFrame()
250
+ results['attr_policy_attrs_total'] = pd.DataFrame()
251
+ results['attr_total_pv_base'] = pd.DataFrame()
252
+ results['attr_cashflow_plot'] = None
253
+ results['attr_scatter_cashflows_base'] = None
254
+
255
+ # --- 3. Present Value Calibration ---
 
 
256
  cluster_pvs = Clusters(pvs)
257
+
258
+ results['pv_total_cf_base'] = cluster_pvs.compare_total(cfs)
259
+ results['pv_policy_attrs_total'] = cluster_pvs.compare_total(pol_data, agg=mean_attrs)
260
+
261
+ results['pv_total_pv_base'] = cluster_pvs.compare_total(pvs)
262
+ results['pv_total_pv_lapse'] = cluster_pvs.compare_total(pvs_lapse50)
263
+ results['pv_total_pv_mort'] = cluster_pvs.compare_total(pvs_mort15)
264
+
265
+ results['pv_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_pvs, scen_titles)
266
+ results['pv_scatter_pvs_base'] = plot_scatter_comparison(cluster_pvs.compare(pvs), 'PV Calib. - PVs (Base)')
267
+
268
+ # --- Summary Comparison Plot Data ---
269
+ # Error metric for key PV column or mean absolute error
270
+
271
+ error_data = {}
272
+
273
+ # Function to safely get error value
274
  def get_error_safe(compare_result, col_name=None):
275
  if compare_result.empty:
276
  return np.nan
277
  if col_name and col_name in compare_result.index:
278
  return abs(compare_result.loc[col_name, 'error'])
279
  else:
280
+ # Use mean absolute error if specific column not found
281
  return abs(compare_result['error']).mean()
282
 
283
+ # Determine key PV column (try common names)
284
  key_pv_col = None
285
  for potential_col in ['PV_NetCF', 'pv_net_cf', 'net_cf_pv', 'PV_Net_CF']:
286
  if potential_col in pvs.columns:
287
  key_pv_col = potential_col
288
  break
289
 
290
+ # Cashflow Calibration Errors
291
+ error_data['CF Calib.'] = [
292
+ get_error_safe(cluster_cfs.compare_total(pvs), key_pv_col),
293
+ get_error_safe(cluster_cfs.compare_total(pvs_lapse50), key_pv_col),
294
+ get_error_safe(cluster_cfs.compare_total(pvs_mort15), key_pv_col)
295
+ ]
296
+
297
+ # Policy Attribute Calibration Errors
298
+ if not loc_vars_attrs.empty:
299
+ error_data['Attr Calib.'] = [
300
+ get_error_safe(cluster_attrs.compare_total(pvs), key_pv_col),
301
+ get_error_safe(cluster_attrs.compare_total(pvs_lapse50), key_pv_col),
302
+ get_error_safe(cluster_attrs.compare_total(pvs_mort15), key_pv_col)
 
 
303
  ]
304
+ else:
305
+ error_data['Attr Calib.'] = [np.nan, np.nan, np.nan]
306
+
307
+ # Present Value Calibration Errors
308
+ error_data['PV Calib.'] = [
309
+ get_error_safe(cluster_pvs.compare_total(pvs), key_pv_col),
310
+ get_error_safe(cluster_pvs.compare_total(pvs_lapse50), key_pv_col),
311
+ get_error_safe(cluster_pvs.compare_total(pvs_mort15), key_pv_col)
312
+ ]
313
 
314
+ # Create Summary Plot
315
  summary_df = pd.DataFrame(error_data, index=['Base', 'Lapse+50%', 'Mort+15%'])
316
 
317
  fig_summary, ax_summary = plt.subplots(figsize=(10, 6))
 
341
  gr.Error(f"Error processing files: {str(e)}")
342
  return {"error": f"Error processing files: {str(e)}"}
343
 
344
+
345
  def create_interface():
346
  with gr.Blocks(title="Cluster Model Points Analysis") as demo:
347
  gr.Markdown("""
 
419
  pv_total_pv_lapse_out = gr.Dataframe(label="PVs - Lapse Stress Total")
420
  pv_total_pv_mort_out = gr.Dataframe(label="PVs - Mortality Stress Total")
421
 
422
+ # --- Helper function to prepare outputs ---
423
  def get_all_output_components():
424
  return [
425
  summary_plot_output,
426
+ # Cashflow Calib Outputs
427
  cf_total_base_table_out, cf_policy_attrs_total_out,
428
  cf_cashflow_plot_out, cf_scatter_cashflows_base_out,
429
  cf_pv_total_base_out, cf_pv_total_lapse_out, cf_pv_total_mort_out,
430
+ # Attribute Calib Outputs
431
  attr_total_cf_base_out, attr_policy_attrs_total_out,
432
  attr_cashflow_plot_out, attr_scatter_cashflows_base_out, attr_total_pv_base_out,
433
+ # PV Calib Outputs
434
  pv_total_cf_base_out, pv_policy_attrs_total_out,
435
  pv_cashflow_plot_out, pv_scatter_pvs_base_out,
436
  pv_total_pv_base_out, pv_total_pv_lapse_out, pv_total_pv_mort_out
437
  ]
438
 
439
+ # --- Action for Analyze Button ---
440
  def handle_analysis(f1, f2, f3, f4, f5, f6, f7):
441
  files = [f1, f2, f3, f4, f5, f6, f7]
 
442
 
443
+ file_paths = []
444
  for i, f_obj in enumerate(files):
445
  if f_obj is None:
446
  gr.Error(f"Missing file input for argument {i+1}. Please upload all files or load examples.")
447
  return [None] * len(get_all_output_components())
448
 
449
+ # If f_obj is a Gradio FileData object (from direct upload)
450
  if hasattr(f_obj, 'name') and isinstance(f_obj.name, str):
451
  file_paths.append(f_obj.name)
452
+ # If f_obj is already a string path (from example load)
453
  elif isinstance(f_obj, str):
454
  file_paths.append(f_obj)
455
  else:
 
463
 
464
  return [
465
  results.get('summary_plot'),
466
+ # CF Calib
467
  results.get('cf_total_base_table'), results.get('cf_policy_attrs_total'),
468
  results.get('cf_cashflow_plot'), results.get('cf_scatter_cashflows_base'),
469
  results.get('cf_pv_total_base'), results.get('cf_pv_total_lapse'), results.get('cf_pv_total_mort'),
470
+ # Attr Calib
471
  results.get('attr_total_cf_base'), results.get('attr_policy_attrs_total'),
472
  results.get('attr_cashflow_plot'), results.get('attr_scatter_cashflows_base'), results.get('attr_total_pv_base'),
473
+ # PV Calib
474
  results.get('pv_total_cf_base'), results.get('pv_policy_attrs_total'),
475
  results.get('pv_cashflow_plot'), results.get('pv_scatter_pvs_base'),
476
  results.get('pv_total_pv_base'), results.get('pv_total_pv_lapse'), results.get('pv_total_pv_mort')
 
483
  outputs=get_all_output_components()
484
  )
485
 
486
+ # --- Action for Load Example Data Button ---
487
  def load_example_files():
488
  missing_files = [fp for fp in EXAMPLE_FILES.values() if not os.path.exists(fp)]
489
  if missing_files: