alidenewade committed (verified) · Commit 4355f45 · Parent(s): 4072b44

Update app.py

Files changed (1): app.py (+264 -105)
app.py CHANGED
@@ -2,7 +2,7 @@ import gradio as gr
 import numpy as np
 import pandas as pd
 from sklearn.cluster import KMeans
-from sklearn.metrics import pairwise_distances_argmin_min, r2_score
+from sklearn.metrics import pairwise_distances_argmin_min  # r2_score is not used in this module
 import matplotlib.pyplot as plt
 import matplotlib.cm
 import io
@@ -22,16 +22,41 @@ EXAMPLE_FILES = {
 }
 
 class Clusters:
-    def __init__(self, loc_vars):
-        self.kmeans = kmeans = KMeans(n_clusters=1000, random_state=0, n_init=10).fit(np.ascontiguousarray(loc_vars))
-        closest, _ = pairwise_distances_argmin_min(kmeans.cluster_centers_, np.ascontiguousarray(loc_vars))
-
-        rep_ids = pd.Series(data=(closest+1)) # 0-based to 1-based indexes
-        rep_ids.name = 'policy_id'
-        rep_ids.index.name = 'cluster_id'
-        self.rep_ids = rep_ids
-
-        self.policy_count = self.agg_by_cluster(pd.DataFrame({'policy_count': [1] * len(loc_vars)}))['policy_count']
+    def __init__(self, loc_vars_df):  # expects a pandas DataFrame
+        # Quantize the input to float32 for KMeans: this lowers memory use and can
+        # speed up the fit, at the cost of minor numerical differences vs. float64.
+        # The data must also be a C-contiguous NumPy array.
+        if loc_vars_df.empty:
+            # Handle the empty-DataFrame case before .astype/.values are called.
+            # KMeans would fail anyway, but this avoids errors before that point.
+            loc_vars_np_float32 = np.array([], dtype=np.float32).reshape(0, loc_vars_df.shape[1] if loc_vars_df.shape[1] > 0 else 0)
+        else:
+            loc_vars_np_float32 = np.ascontiguousarray(loc_vars_df.astype(np.float32).values)
+
+        # Initialize KMeans with algorithm="elkan" for a potential speedup
+        # and fit on the float32 data.
+        self.kmeans = KMeans(
+            n_clusters=1000,
+            random_state=0,
+            n_init=10,
+            algorithm="elkan"  # added for speed
+        ).fit(loc_vars_np_float32)
+
+        # cluster_centers_ is float32 when fitted on float32 data; pass the same
+        # float32 array for the distance calculations.
+        closest, _ = pairwise_distances_argmin_min(
+            self.kmeans.cluster_centers_,
+            loc_vars_np_float32
+        )
+
+        self.rep_ids = pd.Series(data=(closest + 1))  # 0-based to 1-based indexes
+        self.rep_ids.name = 'policy_id'
+        self.rep_ids.index.name = 'cluster_id'
+
+        # policy_count is based on the number of rows in the input data.
+        self.policy_count = self.agg_by_cluster(pd.DataFrame({'policy_count': [1] * loc_vars_np_float32.shape[0]}))['policy_count']
 
     def agg_by_cluster(self, df, agg=None):
         """Aggregate columns by cluster"""
@@ -43,21 +68,46 @@ class Clusters:
 
     def extract_reps(self, df):
         """Extract the rows of representative policies"""
-        temp = pd.merge(self.rep_ids, df.reset_index(), how='left', on='policy_id')
-        temp.index.name = 'cluster_id'
-        return temp.drop('policy_id', axis=1)
+        # df is expected to carry 'policy_id' either as its index or as a column.
+        if 'policy_id' not in df.columns and df.index.name != 'policy_id':
+            # This should not happen if inputs are consistent; assume the unnamed
+            # index is the policy identifier and name it accordingly.
+            df_indexed = df.copy()
+            if df_indexed.index.name is None:
+                gr.Warning("DataFrame passed to extract_reps has no index name; assuming the index is policy_id.")
+                df_indexed.index.name = 'policy_id'
+            temp = pd.merge(self.rep_ids, df_indexed.reset_index(), how='left', on='policy_id')
+        elif 'policy_id' in df.columns and df.index.name == 'policy_id':
+            # policy_id is both the index and a column: merge on the column.
+            temp = pd.merge(self.rep_ids, df, how='left', on='policy_id')
+        elif df.index.name == 'policy_id':
+            temp = pd.merge(self.rep_ids, df.reset_index(), how='left', on='policy_id')
+        else:
+            # 'policy_id' is a column, not the index.
+            temp = pd.merge(self.rep_ids, df.reset_index(drop=df.index.name is None), how='left', on='policy_id')
+
+        # The merge result's index is not cluster_id by default; take it from rep_ids.
+        temp = temp.set_index(self.rep_ids.index)
+        return temp.drop('policy_id', axis=1, errors='ignore')
 
     def extract_and_scale_reps(self, df, agg=None):
         """Extract and scale the rows of representative policies"""
+        extracted_df = self.extract_reps(df)
         if agg:
-            cols = df.columns
+            cols = extracted_df.columns  # use the columns of extracted_df
             mult = pd.DataFrame({c: (self.policy_count if (c not in agg or agg[c] == 'sum') else 1) for c in cols})
-            # Ensure mult has same index as extract_reps(df) for proper alignment
-            extracted_df = self.extract_reps(df)
-            mult.index = extracted_df.index
+            mult.index = extracted_df.index  # align index
             return extracted_df.mul(mult)
         else:
-            return self.extract_reps(df).mul(self.policy_count, axis=0)
+            return extracted_df.mul(self.policy_count, axis=0)
 
     def compare(self, df, agg=None):
         """Returns a multi-indexed Dataframe comparing actual and estimate"""
@@ -68,7 +118,6 @@ class Clusters:
     def compare_total(self, df, agg=None):
         """Aggregate df by columns"""
         if agg:
-            # Calculate actual values using specified aggregation
             actual_values = {}
             for col in df.columns:
                 if agg.get(col, 'sum') == 'mean':
@@ -77,13 +126,19 @@ class Clusters:
                     actual_values[col] = df[col].sum()
             actual = pd.Series(actual_values)
 
-            # Calculate estimate values
             reps_unscaled = self.extract_reps(df)
             estimate_values = {}
 
-            for col in df.columns:
+            for col in df.columns:  # iterate over the original columns so all are covered
+                if col not in reps_unscaled.columns:  # column may have been dropped or not selected
+                    if agg.get(col, 'sum') == 'mean':
+                        estimate_values[col] = np.nan
+                    else:
+                        estimate_values[col] = 0
+                    gr.Warning(f"Column '{col}' not found in representative policies output for 'compare_total'. Estimate will be 0/NaN.")
+                    continue
+
                 if agg.get(col, 'sum') == 'mean':
-                    # Weighted average for mean columns
                     weighted_sum = (reps_unscaled[col] * self.policy_count).sum()
                     total_weight = self.policy_count.sum()
                     estimate_values[col] = weighted_sum / total_weight if total_weight > 0 else 0
@@ -96,12 +151,14 @@ class Clusters:
             actual = df.sum()
             estimate = self.extract_and_scale_reps(df).sum()
 
-        # Calculate error, handling division by zero
+        # Align actual and estimate before the error calculation (division by zero -> 0).
+        actual, estimate = actual.align(estimate, fill_value=0)
         error = np.where(actual != 0, estimate / actual - 1, 0)
 
         return pd.DataFrame({'actual': actual, 'estimate': estimate, 'error': error})
 
 
+# --- Plotting functions ---
 def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
     """Create cashflow comparison plots"""
     if not cfs_list or not cluster_obj or not titles:
@@ -110,7 +167,6 @@ def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
     if num_plots == 0:
         return None
 
-    # Determine subplot layout
     cols = 2
     rows = (num_plots + cols - 1) // cols
 
@@ -119,12 +175,17 @@ def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
 
     for i, (df, title) in enumerate(zip(cfs_list, titles)):
        if i < len(axes):
-            comparison = cluster_obj.compare_total(df)
+            # compare_total expects 'policy_id' as the index (or a column) of df.
+            if df.index.name != 'policy_id' and 'policy_id' not in df.columns:
+                gr.Warning(f"DataFrame for plot '{title}' does not have 'policy_id' as index or column. Results may be incorrect.")
+
+            comparison = cluster_obj.compare_total(df.set_index('policy_id') if 'policy_id' in df.columns and df.index.name != 'policy_id' else df)
            comparison[['actual', 'estimate']].plot(ax=axes[i], grid=True, title=title)
            axes[i].set_xlabel('Time')
            axes[i].set_ylabel('Value')
 
-    # Hide any unused subplots
    for j in range(i + 1, len(axes)):
        fig.delaxes(axes[j])
@@ -139,7 +200,6 @@ def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
 def plot_scatter_comparison(df_compare_output, title):
     """Create scatter plot comparison from compare() output"""
     if df_compare_output is None or df_compare_output.empty:
-        # Create a blank plot with a message
         fig, ax = plt.subplots(figsize=(12, 8))
         ax.text(0.5, 0.5, "No data to display", ha='center', va='center', fontsize=15)
         ax.set_title(title)
@@ -153,29 +213,28 @@ def plot_scatter_comparison(df_compare_output, title):
     fig, ax = plt.subplots(figsize=(12, 8))
 
     if not isinstance(df_compare_output.index, pd.MultiIndex) or df_compare_output.index.nlevels < 2:
-        gr.Warning("Scatter plot data is not in the expected multi-index format. Plotting raw actual vs estimate.")
-        ax.scatter(df_compare_output['actual'], df_compare_output['estimate'], s=9, alpha=0.6)
+        gr.Warning("Scatter plot data is not in the expected multi-index format. Plotting raw actual vs estimate.")
+        ax.scatter(df_compare_output['actual'], df_compare_output['estimate'], s=9, alpha=0.6)
     else:
         unique_levels = df_compare_output.index.get_level_values(1).unique()
         colors = matplotlib.cm.rainbow(np.linspace(0, 1, len(unique_levels)))
 
         for item_level, color_val in zip(unique_levels, colors):
             subset = df_compare_output.xs(item_level, level=1)
-            ax.scatter(subset['actual'], subset['estimate'], color=color_val, s=9, alpha=0.6, label=item_level)
-        if len(unique_levels) > 1 and len(unique_levels) <= 10:
-            ax.legend(title=df_compare_output.index.names[1])
+            ax.scatter(subset['actual'], subset['estimate'], color=color_val, s=9, alpha=0.6, label=str(item_level))  # ensure the label is a string
+        if len(unique_levels) > 1 and len(unique_levels) <= 20:  # slightly higher legend item limit
+            ax.legend(title=str(df_compare_output.index.names[1]))
 
     ax.set_xlabel('Actual')
     ax.set_ylabel('Estimate')
     ax.set_title(title)
     ax.grid(True)
 
-    # Draw identity line
     lims = [
-        np.min([ax.get_xlim(), ax.get_ylim()]),
-        np.max([ax.get_xlim(), ax.get_ylim()]),
+        np.nanmin([ax.get_xlim(), ax.get_ylim()]),  # nanmin/nanmax tolerate NaNs
+        np.nanmax([ax.get_xlim(), ax.get_ylim()]),
     ]
-    if lims[0] != lims[1]:
+    if lims[0] != lims[1] and np.isfinite(lims[0]) and np.isfinite(lims[1]):  # check for valid limits
         ax.plot(lims, lims, 'r-', linewidth=0.5)
         ax.set_xlim(lims)
         ax.set_ylim(lims)
@@ -187,28 +246,55 @@ def plot_scatter_comparison(df_compare_output, title):
     plt.close(fig)
     return img
 
-
+# --- Main processing function ---
+# DataFrames passed to the Clusters methods are expected to be indexed by 'policy_id'.
 def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
                   policy_data_path, pv_base_path, pv_lapse_path, pv_mort_path):
     """Main processing function - now accepts file paths"""
     try:
-        # Read uploaded files using paths
+        # Consider engine='calamine' for faster Excel reading if available
+        # (pip install pandas[calamine]), e.g.:
+        # cfs = pd.read_excel(cashflow_base_path, index_col=0, engine='calamine')
         cfs = pd.read_excel(cashflow_base_path, index_col=0)
         cfs_lapse50 = pd.read_excel(cashflow_lapse_path, index_col=0)
         cfs_mort15 = pd.read_excel(cashflow_mort_path, index_col=0)
 
         pol_data_full = pd.read_excel(policy_data_path, index_col=0)
-        # Ensure the correct columns are selected for pol_data
         required_cols = ['age_at_entry', 'policy_term', 'sum_assured', 'duration_mth']
-        if all(col in pol_data_full.columns for col in required_cols):
-            pol_data = pol_data_full[required_cols]
+
+        # Name the index 'policy_id' where it is unnamed, assuming it is the policy
+        # identifier, and normalize so 'policy_id' is the index, not a duplicated column.
+        for df in [cfs, cfs_lapse50, cfs_mort15, pol_data_full]:
+            if df.index.name is None:
+                df.index.name = 'policy_id'
+            if 'policy_id' not in df.columns and df.index.name == 'policy_id':
+                df.reset_index(inplace=True)
+                df.set_index('policy_id', inplace=True)
+
+        if all(col in pol_data_full.columns or col == pol_data_full.index.name for col in required_cols):
+            # If policy_id is the index it will not appear in columns; adjust the selection.
+            cols_to_select = [col for col in required_cols if col in pol_data_full.columns]
+            # For simplicity, assume required_cols are actual data columns; a required
+            # column that is also the index is left to later selection or errors.
+            pol_data = pol_data_full[cols_to_select].copy()  # .copy() avoids SettingWithCopyWarning
+            # For K-Means, policy_id itself is not a feature.
         else:
-            gr.Warning(f"Policy data might be missing required columns. Found: {pol_data_full.columns.tolist()}")
-            pol_data = pol_data_full
+            missing_req_cols = [col for col in required_cols if col not in pol_data_full.columns and col != pol_data_full.index.name]
+            gr.Warning(f"Policy data might be missing required columns: {missing_req_cols}. Found: {pol_data_full.columns.tolist()}")
+            pol_data = pol_data_full  # fallback; must be numeric for clustering/scaling
 
         pvs = pd.read_excel(pv_base_path, index_col=0)
         pvs_lapse50 = pd.read_excel(pv_lapse_path, index_col=0)
         pvs_mort15 = pd.read_excel(pv_mort_path, index_col=0)
+
+        for df in [pvs, pvs_lapse50, pvs_mort15]:
+            if df.index.name is None:
+                df.index.name = 'policy_id'
+            if 'policy_id' not in df.columns and df.index.name == 'policy_id':
+                df.reset_index(inplace=True)
+                df.set_index('policy_id', inplace=True)
 
         cfs_list = [cfs, cfs_lapse50, cfs_mort15]
         scen_titles = ['Base', 'Lapse+50%', 'Mort+15%']
@@ -217,8 +303,14 @@ def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
 
         mean_attrs = {'age_at_entry':'mean', 'policy_term':'mean', 'duration_mth':'mean', 'sum_assured': 'sum'}
 
+        # DataFrames passed to Clusters should be policy_id-indexed so that .values
+        # contains only feature columns (the index is excluded from df.values).
+
         # --- 1. Cashflow Calibration ---
-        cluster_cfs = Clusters(cfs)
+        cluster_cfs = Clusters(cfs.reset_index().set_index('policy_id'))  # pass with policy_id as index
 
         results['cf_total_base_table'] = cluster_cfs.compare_total(cfs)
         results['cf_policy_attrs_total'] = cluster_cfs.compare_total(pol_data, agg=mean_attrs)
@@ -231,15 +323,22 @@ def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
         results['cf_scatter_cashflows_base'] = plot_scatter_comparison(cluster_cfs.compare(cfs), 'Cashflow Calib. - Cashflows (Base)')
 
         # --- 2. Policy Attribute Calibration ---
-        # Standardize policy attributes
-        if not pol_data.empty and (pol_data.max() - pol_data.min()).all() != 0:
-            loc_vars_attrs = (pol_data - pol_data.min()) / (pol_data.max() - pol_data.min())
+        loc_vars_attrs = pd.DataFrame()  # initialize
+        if not pol_data.empty:
+            # Ensure pol_data is purely numeric for scaling and KMeans.
+            numeric_pol_data = pol_data.select_dtypes(include=np.number)
+            if not numeric_pol_data.empty and not (numeric_pol_data.max() - numeric_pol_data.min() == 0).all():
+                loc_vars_attrs = (numeric_pol_data - numeric_pol_data.min()) / \
+                                 (numeric_pol_data.max() - numeric_pol_data.min())
+            else:
+                gr.Warning("Policy data for attribute calibration is empty, non-numeric, or has no variance. Skipping attribute calibration content.")
+                loc_vars_attrs = numeric_pol_data
         else:
-            gr.Warning("Policy data for attribute calibration is empty or has no variance. Skipping attribute calibration plots.")
-            loc_vars_attrs = pol_data
-
+            gr.Warning("Policy data is empty. Skipping attribute calibration content.")
+
         if not loc_vars_attrs.empty:
-            cluster_attrs = Clusters(loc_vars_attrs)
+            cluster_attrs = Clusters(loc_vars_attrs.reset_index().set_index('policy_id'))  # pass with policy_id as index
             results['attr_total_cf_base'] = cluster_attrs.compare_total(cfs)
             results['attr_policy_attrs_total'] = cluster_attrs.compare_total(pol_data, agg=mean_attrs)
             results['attr_total_pv_base'] = cluster_attrs.compare_total(pvs)
@@ -249,11 +348,12 @@ def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
             results['attr_total_cf_base'] = pd.DataFrame()
             results['attr_policy_attrs_total'] = pd.DataFrame()
             results['attr_total_pv_base'] = pd.DataFrame()
-            results['attr_cashflow_plot'] = None
-            results['attr_scatter_cashflows_base'] = None
+            results['attr_cashflow_plot'] = plot_scatter_comparison(pd.DataFrame(), 'Policy Attr. Calib. - Cashflows (Base) - No Data')  # empty placeholder plot
+            results['attr_scatter_cashflows_base'] = plot_scatter_comparison(pd.DataFrame(), 'Policy Attr. Calib. - Cashflows (Base) - No Data')
 
         # --- 3. Present Value Calibration ---
-        cluster_pvs = Clusters(pvs)
+        cluster_pvs = Clusters(pvs.reset_index().set_index('policy_id'))  # pass with policy_id as index
 
         results['pv_total_cf_base'] = cluster_pvs.compare_total(cfs)
         results['pv_policy_attrs_total'] = cluster_pvs.compare_total(pol_data, agg=mean_attrs)
@@ -266,52 +366,47 @@ def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
         results['pv_scatter_pvs_base'] = plot_scatter_comparison(cluster_pvs.compare(pvs), 'PV Calib. - PVs (Base)')
 
         # --- Summary Comparison Plot Data ---
-        # Error metric for key PV column or mean absolute error
-
         error_data = {}
 
-        # Function to safely get error value
         def get_error_safe(compare_result, col_name=None):
-            if compare_result.empty:
+            if compare_result is None or compare_result.empty or 'error' not in compare_result.columns:
                 return np.nan
             if col_name and col_name in compare_result.index:
                 return abs(compare_result.loc[col_name, 'error'])
             else:
-                # Use mean absolute error if specific column not found
                 return abs(compare_result['error']).mean()
 
-        # Determine key PV column (try common names)
         key_pv_col = None
+        # Check against the original pvs columns, which have not been stripped;
+        # this is a quick re-read of the file purely for its column names.
+        original_pvs_cols = pd.read_excel(pv_base_path).columns
         for potential_col in ['PV_NetCF', 'pv_net_cf', 'net_cf_pv', 'PV_Net_CF']:
-            if potential_col in pvs.columns:
+            if potential_col in original_pvs_cols:
                 key_pv_col = potential_col
                 break
 
-        # Cashflow Calibration Errors
         error_data['CF Calib.'] = [
-            get_error_safe(cluster_cfs.compare_total(pvs), key_pv_col),
-            get_error_safe(cluster_cfs.compare_total(pvs_lapse50), key_pv_col),
-            get_error_safe(cluster_cfs.compare_total(pvs_mort15), key_pv_col)
+            get_error_safe(results.get('cf_pv_total_base'), key_pv_col),
+            get_error_safe(results.get('cf_pv_total_lapse'), key_pv_col),
+            get_error_safe(results.get('cf_pv_total_mort'), key_pv_col)
         ]
 
-        # Policy Attribute Calibration Errors
         if not loc_vars_attrs.empty:
             error_data['Attr Calib.'] = [
-                get_error_safe(cluster_attrs.compare_total(pvs), key_pv_col),
-                get_error_safe(cluster_attrs.compare_total(pvs_lapse50), key_pv_col),
-                get_error_safe(cluster_attrs.compare_total(pvs_mort15), key_pv_col)
+                get_error_safe(results.get('attr_total_pv_base'), key_pv_col),
+                get_error_safe(cluster_attrs.compare_total(pvs_lapse50), key_pv_col),  # recalculated for pvs_lapse50
+                get_error_safe(cluster_attrs.compare_total(pvs_mort15), key_pv_col)   # recalculated for pvs_mort15
            ]
        else:
            error_data['Attr Calib.'] = [np.nan, np.nan, np.nan]
 
-        # Present Value Calibration Errors
        error_data['PV Calib.'] = [
-            get_error_safe(cluster_pvs.compare_total(pvs), key_pv_col),
-            get_error_safe(cluster_pvs.compare_total(pvs_lapse50), key_pv_col),
-            get_error_safe(cluster_pvs.compare_total(pvs_mort15), key_pv_col)
+            get_error_safe(results.get('pv_total_pv_base'), key_pv_col),
+            get_error_safe(results.get('pv_total_pv_lapse'), key_pv_col),
+            get_error_safe(results.get('pv_total_pv_mort'), key_pv_col)
        ]
 
-        # Create Summary Plot
        summary_df = pd.DataFrame(error_data, index=['Base', 'Lapse+50%', 'Mort+15%'])
 
        fig_summary, ax_summary = plt.subplots(figsize=(10, 6))
@@ -335,13 +430,15 @@ def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
         gr.Error(f"File not found: {e.filename}. Please ensure example files are in '{EXAMPLE_DATA_DIR}' or all files are uploaded.")
         return {"error": f"File not found: {e.filename}"}
     except KeyError as e:
-        gr.Error(f"A required column is missing from one of the excel files: {e}. Please check data format.")
-        return {"error": f"Missing column: {e}"}
+        # The KeyError may come from a column that is actually the index.
+        gr.Error(f"A required column or index is missing or misnamed: {e}. Please check the data format and ensure 'policy_id' is handled as the index of the feature dataframes.")
+        return {"error": f"Missing column/index: {e}"}
     except Exception as e:
-        gr.Error(f"Error processing files: {str(e)}")
+        import traceback
+        gr.Error(f"Error processing files: {str(e)}. Trace: {traceback.format_exc()}")
         return {"error": f"Error processing files: {str(e)}"}
 
-
+# --- Gradio interface ---
 def create_interface():
     with gr.Blocks(title="Cluster Model Points Analysis") as demo:
         gr.Markdown("""
@@ -351,13 +448,15 @@ def create_interface():
     Upload your Excel files or use the example data to analyze cashflows, policy attributes, and present values using different calibration methods.
 
     **Required Files (Excel .xlsx):**
-    - Cashflows - Base Scenario
-    - Cashflows - Lapse Stress (+50%)
-    - Cashflows - Mortality Stress (+15%)
-    - Policy Data (including 'age_at_entry', 'policy_term', 'sum_assured', 'duration_mth')
-    - Present Values - Base Scenario
-    - Present Values - Lapse Stress
-    - Present Values - Mortality Stress
+    - Cashflows - Base Scenario (index = policy_id, columns = time periods)
+    - Cashflows - Lapse Stress (+50%) (index = policy_id)
+    - Cashflows - Mortality Stress (+15%) (index = policy_id)
+    - Policy Data (index = policy_id; columns include 'age_at_entry', 'policy_term', 'sum_assured', 'duration_mth')
+    - Present Values - Base Scenario (index = policy_id, columns = PV components such as 'PV_NetCF')
+    - Present Values - Lapse Stress (index = policy_id)
+    - Present Values - Mortality Stress (index = policy_id)
+
+    *Note: ensure 'policy_id' is the index of every input file for correct processing.*
     """)
 
     with gr.Row():
@@ -404,7 +503,11 @@ def create_interface():
                     attr_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate) Across Scenarios")
                     attr_scatter_cashflows_base_out = gr.Image(label="Scatter Plot - Per-Cluster Cashflows (Base Scenario)")
                     with gr.Accordion("Present Value Comparisons (Total)", open=False):
-                        attr_total_pv_base_out = gr.Dataframe(label="PVs - Base Scenario Total")
+                        with gr.Row():  # Row wrapper for consistency with the other tabs
+                            attr_total_pv_base_out = gr.Dataframe(label="PVs - Base Scenario Total")
+                            # Placeholders for the other scenarios, if added later:
+                            # attr_total_pv_lapse_out = gr.Dataframe(label="PVs - Lapse Stress Total")
+                            # attr_total_pv_mort_out = gr.Dataframe(label="PVs - Mortality Stress Total")
 
                 with gr.TabItem("💰 Present Value Calibration"):
                     gr.Markdown("### Results: Using Present Values (Base Scenario) as Calibration Variables")
@@ -441,24 +544,28 @@ def create_interface():
         files = [f1, f2, f3, f4, f5, f6, f7]
 
         file_paths = []
+        # Fail fast if any file slot is empty (mixing uploads and examples is not supported).
+        if any(f_obj is None for f_obj in files):
+            gr.Error("Missing file input for one or more fields. Please upload all required files or load the complete example dataset.")
+            return [None] * len(get_all_output_components())
+
         for i, f_obj in enumerate(files):
-            if f_obj is None:
-                gr.Error(f"Missing file input for argument {i+1}. Please upload all files or load examples.")
-                return [None] * len(get_all_output_components())
-
-            # If f_obj is a Gradio FileData object (from direct upload)
-            if hasattr(f_obj, 'name') and isinstance(f_obj.name, str):
+            # f_obj is a Gradio file wrapper (upload) or a plain str (example load).
+            if hasattr(f_obj, 'name') and isinstance(f_obj.name, str):
                 file_paths.append(f_obj.name)
-            # If f_obj is already a string path (from example load)
-            elif isinstance(f_obj, str):
-                file_paths.append(f_obj)
-            else:
+            elif isinstance(f_obj, str):  # path from example load
+                file_paths.append(f_obj)
+            else:  # should not happen if inputs are Files or paths
                 gr.Error(f"Invalid file input for argument {i+1}. Type: {type(f_obj)}")
                 return [None] * len(get_all_output_components())
 
         results = process_files(*file_paths)
 
-        if "error" in results:
+        if "error" in results:  # process_files returned an error dict; the message was already shown
             return [None] * len(get_all_output_components())
 
         return [
@@ -485,18 +592,48 @@ def create_interface():
 
     # --- Action for Load Example Data Button ---
     def load_example_files():
-        missing_files = [fp for fp in EXAMPLE_FILES.values() if not os.path.exists(fp)]
-        if missing_files:
-            gr.Error(f"Missing example data files in '{EXAMPLE_DATA_DIR}': {', '.join(missing_files)}. Please ensure they exist.")
-            return [None] * 7
+        # Create minimal dummy example files when any are missing, so the demo can still run.
+        os.makedirs(EXAMPLE_DATA_DIR, exist_ok=True)  # ensure the directory exists
+
+        missing_files = []
+        for key, fp in EXAMPLE_FILES.items():
+            if not os.path.exists(fp):
+                missing_files.append(fp)
+                try:
+                    dummy_df_data = {'policy_id': [1, 2, 3], 'col1': [0.1, 0.2, 0.3], 'col2': [10, 20, 30]}
+                    if "cashflow" in key or "pv" in key:  # time-series-like files
+                        dummy_df_data = {'policy_id': [1, 2, 3], '0': [1, 2, 3], '1': [4, 5, 6]}
+                    elif "policy_data" in key:
+                        dummy_df_data = {'policy_id': [1, 2, 3], 'age_at_entry': [20, 30, 40], 'policy_term': [10, 20, 15],
+                                         'sum_assured': [1000, 2000, 1500], 'duration_mth': [5, 10, 7]}
+
+                    dummy_df = pd.DataFrame(dummy_df_data).set_index('policy_id')
+                    dummy_df.to_excel(fp)
+                    gr.Warning(f"Example file '{fp}' was missing and a dummy file has been created. Results may not be meaningful.")
+                except Exception as e:
+                    gr.Warning(f"Could not create dummy file for {fp}: {e}")
+
+        if missing_files and not all(os.path.exists(fp) for fp in EXAMPLE_FILES.values()):  # re-check after dummy creation
+            gr.Error(f"Critical example data files are missing from '{EXAMPLE_DATA_DIR}': {', '.join(missing_files)}. Please ensure they exist or check permissions.")
+            return [None] * 7  # one None per file input
 
         gr.Info("Example data paths loaded. Click 'Analyze Dataset'.")
+        # Return updated File components pointing at the example paths.
         return [
-            EXAMPLE_FILES["cashflow_base"], EXAMPLE_FILES["cashflow_lapse"], EXAMPLE_FILES["cashflow_mort"],
-            EXAMPLE_FILES["policy_data"], EXAMPLE_FILES["pv_base"], EXAMPLE_FILES["pv_lapse"],
-            EXAMPLE_FILES["pv_mort"]
+            gr.File(value=EXAMPLE_FILES["cashflow_base"], label=cashflow_base_input.label),
+            gr.File(value=EXAMPLE_FILES["cashflow_lapse"], label=cashflow_lapse_input.label),
+            gr.File(value=EXAMPLE_FILES["cashflow_mort"], label=cashflow_mort_input.label),
+            gr.File(value=EXAMPLE_FILES["policy_data"], label=policy_data_input.label),
+            gr.File(value=EXAMPLE_FILES["pv_base"], label=pv_base_input.label),
+            gr.File(value=EXAMPLE_FILES["pv_lapse"], label=pv_lapse_input.label),
+            gr.File(value=EXAMPLE_FILES["pv_mort"], label=pv_mort_input.label)
        ]
 
    load_example_btn.click(
        load_example_files,
        inputs=[],
@@ -509,8 +646,30 @@ def create_interface():
 if __name__ == "__main__":
     if not os.path.exists(EXAMPLE_DATA_DIR):
         os.makedirs(EXAMPLE_DATA_DIR)
-        print(f"Created directory '{EXAMPLE_DATA_DIR}'. Please place example Excel files there.")
-        print(f"Expected files in '{EXAMPLE_DATA_DIR}': {list(EXAMPLE_FILES.values())}")
+        print(f"Created directory '{EXAMPLE_DATA_DIR}'. Place example Excel files there, or dummies will be generated.")
+
+    # Simple check, with dummy-file creation when example data is not present.
+    for key, fp in EXAMPLE_FILES.items():
+        if not os.path.exists(fp):
+            print(f"Example file {fp} not found. Attempting to create a dummy file.")
+            try:
+                dummy_df_data = {'policy_id': [1, 2, 3], 'col1': [0.1, 0.2, 0.3], 'col2': [10, 20, 30]}
+                if "cashflow" in key or "pv" in key:
+                    dummy_df_data = {f'{i}': np.random.rand(3) for i in range(10)}  # 10 time periods
+                    dummy_df_data['policy_id'] = [f'P{j}' for j in range(3)]
+                elif "policy_data" in key:
+                    dummy_df_data = {'policy_id': [f'P{j}' for j in range(3)],
+                                     'age_at_entry': np.random.randint(20, 50, 3),
+                                     'policy_term': np.random.randint(10, 30, 3),
+                                     'sum_assured': np.random.randint(10000, 50000, 3),
+                                     'duration_mth': np.random.randint(1, 120, 3)}
+
+                dummy_df = pd.DataFrame(dummy_df_data).set_index('policy_id')
+                dummy_df.to_excel(fp)
+                print(f"Dummy file for '{fp}' created.")
+            except Exception as e:
+                print(f"Could not create dummy file for {fp}: {e}")
 
     demo_app = create_interface()
     demo_app.launch()
 