alidenewade committed on
Commit 91e876f · verified · 1 Parent(s): ba285ba

Update app.py

Files changed (1): app.py +49 -32
app.py CHANGED
@@ -5,15 +5,17 @@ from sklearn.cluster import KMeans
 from sklearn.metrics import r2_score, pairwise_distances_argmin_min
 import matplotlib.pyplot as plt
 import io
+import os # For checking file extensions
 
 def cluster_analysis(policy_file, cashflow_file, pv_file, num_clusters):
     """
     Performs cluster analysis for actuarial model point selection.
+    Accepts both Excel and CSV files.
 
     Args:
-        policy_file: Gradio File object for policy data (Excel).
-        cashflow_file: Gradio File object for cashflow data (Excel).
-        pv_file: Gradio File object for present value data (Excel).
+        policy_file: Gradio File object for policy data.
+        cashflow_file: Gradio File object for cashflow data.
+        pv_file: Gradio File object for present value data.
         num_clusters: Number of clusters (model points) to generate.
 
     Returns:
@@ -35,19 +37,30 @@ def cluster_analysis(policy_file, cashflow_file, pv_file, num_clusters):
     print(f"PV file received: {pv_file.name if pv_file else 'None'}")
     print("="*50 + "\n")
 
+    # Helper function to read files based on extension
+    def read_data_file(file_obj, index_col=None):
+        if file_obj is None:
+            raise ValueError("File object is None.")
+
+        file_path = file_obj.name
+        file_extension = os.path.splitext(file_path)[1].lower()
+
+        if file_extension in ['.xlsx', '.xls']:
+            print(f"Attempting to read Excel file: {file_path}")
+            return pd.read_excel(file_path, index_col=index_col)
+        elif file_extension == '.csv':
+            print(f"Attempting to read CSV file: {file_path}")
+            # Consider adding 'sep' argument if CSV delimiter is not comma, e.g., sep=';'
+            return pd.read_csv(file_path, index_col=index_col)
+        else:
+            raise ValueError(f"Unsupported file type: {file_extension}. Please upload .xlsx, .xls, or .csv files.")
+
     # 1. Basic checks and file reading
     try:
-        if policy_file is None or cashflow_file is None or pv_file is None:
-            missing_files = []
-            if policy_file is None: missing_files.append("Policy Data")
-            if cashflow_file is None: missing_files.append("Cashflow Data")
-            if pv_file is None: missing_files.append("Present Value Data")
-            raise ValueError(f"Missing required input file(s): {', '.join(missing_files)}. Please upload all files.")
-
-        policy_df = pd.read_excel(policy_file.name)
+        policy_df = read_data_file(policy_file)
         # index_col=0 is crucial. Ensure the first column contains unique policy identifiers.
-        cashflow_df = pd.read_excel(cashflow_file.name, index_col=0)
-        pv_df = pd.read_excel(pv_file.name, index_col=0)
+        cashflow_df = read_data_file(cashflow_file, index_col=0)
+        pv_df = read_data_file(pv_file, index_col=0)
         print(f"[{pd.Timestamp.now()}] Files read successfully.")
         print(f"Policy data shape: {policy_df.shape}, Columns: {policy_df.columns.tolist()}")
         print(f"Cashflow data shape: {cashflow_df.shape}, Index type: {cashflow_df.index.dtype}")
@@ -60,10 +73,13 @@ def cluster_analysis(policy_file, cashflow_file, pv_file, num_clusters):
 
     # 2. Validate Policy Data Columns
     required_cols = ['IssueAge', 'PolicyTerm', 'SumAssured', 'Duration']
+    # Strip whitespace from column names for robust matching
+    policy_df.columns = policy_df.columns.str.strip()
+
     if not all(col in policy_df.columns for col in required_cols):
         found_cols = policy_df.columns.tolist()
         error_msg = (f"Policy data missing required columns. Expected: {required_cols}. "
-                     f"Found: {found_cols}. Please check your policy Excel file column headers.")
+                     f"Found: {found_cols}. Please check your policy data column headers for typos or extra spaces.")
         print(f"[{pd.Timestamp.now()}] {error_msg}")
         return (None, None, None, error_msg)
 
@@ -71,7 +87,8 @@ def cluster_analysis(policy_file, cashflow_file, pv_file, num_clusters):
     try:
         X = policy_df[required_cols].fillna(0)
         # Scale data, handle cases where std is 0 (e.g., all values are the same for a feature)
-        X_scaled = X.apply(lambda x: (x - x.mean()) / x.std() if x.std() != 0 else 0, axis=0)
+        # Add a small epsilon to avoid division by zero if all values are identical
+        X_scaled = X.apply(lambda x: (x - x.mean()) / (x.std() if x.std() != 0 else 1e-9), axis=0)
         print(f"[{pd.Timestamp.now()}] Policy attributes scaled.")
     except Exception as e:
         error_msg = f"Error preparing data for clustering: {e}"
@@ -91,7 +108,8 @@ def cluster_analysis(policy_file, cashflow_file, pv_file, num_clusters):
         print(f"[{pd.Timestamp.now()}] Warning: Number of clusters ({original_num_clusters}) "
               f"exceeded number of samples ({n_samples}). Reduced to {num_clusters}.")
 
-    kmeans = KMeans(n_clusters=num_clusters, random_state=42, n_init=10)
+    # Use 'auto' for n_init for newer scikit-learn versions
+    kmeans = KMeans(n_clusters=num_clusters, random_state=42, n_init='auto')
     kmeans.fit(X_scaled)
     policy_df['Cluster'] = kmeans.labels_
     print(f"[{pd.Timestamp.now()}] Clustering successful with {num_clusters} clusters.")
@@ -109,7 +127,7 @@ def cluster_analysis(policy_file, cashflow_file, pv_file, num_clusters):
         counts = policy_df['Cluster'].value_counts()
         model_points['Weight'] = model_points['Cluster'].map(counts)
         print(f"[{pd.Timestamp.now()}] Model points selected and weights calculated. Model points shape: {model_points.shape}")
-        print(f"Model points indices: {model_points.index.tolist()}")
+        print(f"Model points indices (first 5): {model_points.index.tolist()[:5]}...")
     except Exception as e:
         error_msg = f"Error selecting model points or calculating weights: {e}"
         print(f"[{pd.Timestamp.now()}] {error_msg}")
@@ -126,15 +144,15 @@ def cluster_analysis(policy_file, cashflow_file, pv_file, num_clusters):
         # Check if all model_points indices exist in cashflow_df and pv_df
         missing_cf_indices = [idx for idx in model_points.index if idx not in cashflow_df.index]
         if missing_cf_indices:
-            raise KeyError(f"Cashflow data is missing entries for model point indices: {missing_cf_indices[:5]}...")  # Show first 5
+            raise KeyError(f"Cashflow data is missing entries for model point indices: {missing_cf_indices[:5]}... Please check Cashflow data's first column (index).")
 
         proxy_cashflows = cashflow_df.loc[model_points.index].multiply(model_points['Weight'], axis=0).sum()
         seriatim_cashflows = cashflow_df.sum()
         print(f"[{pd.Timestamp.now()}] Cashflows aggregated.")
     except KeyError as e:
         error_msg = (f"Key Error during cashflow aggregation. "
-                     f"Ensure 'policy_id' (or equivalent) column from your policy data "
-                     f"is set as the index (first column) in your Cashflow Excel file: {e}")
+                     f"Ensure the first column of your Cashflow Excel/CSV file contains policy IDs "
+                     f"that match the indices from your Policy data: {e}")
         print(f"[{pd.Timestamp.now()}] {error_msg}")
         return (None, None, None, error_msg)
     except Exception as e:
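Note: the proxy aggregation here weights each representative policy's cashflow vector by its cluster size and sums across model points, while the seriatim total simply sums every policy. A worked toy example (assumed numbers):

```python
import pandas as pd

# Two model points with cluster sizes 3 and 2 (assumed toy data).
cashflows = pd.DataFrame({"t1": [10.0, 20.0], "t2": [5.0, 8.0]}, index=["P1", "P7"])
weights = pd.Series([3, 2], index=["P1", "P7"])

proxy = cashflows.multiply(weights, axis=0).sum()
print(proxy.tolist())  # [70.0, 31.0]: 10*3 + 20*2 and 5*3 + 8*2
```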
@@ -169,7 +187,7 @@ def cluster_analysis(policy_file, cashflow_file, pv_file, num_clusters):
     try:
         missing_pv_indices = [idx for idx in model_points.index if idx not in pv_df.index]
         if missing_pv_indices:
-            raise KeyError(f"PV data is missing entries for model point indices: {missing_pv_indices[:5]}...")  # Show first 5
+            raise KeyError(f"PV data is missing entries for model point indices: {missing_pv_indices[:5]}... Please check PV data's first column (index).")
 
         # Assuming PV data has only one column or the relevant column is the first one
         proxy_pv = pv_df.loc[model_points.index].multiply(model_points['Weight'], axis=0).sum().values[0]
@@ -177,8 +195,8 @@ def cluster_analysis(policy_file, cashflow_file, pv_file, num_clusters):
         print(f"[{pd.Timestamp.now()}] Present Values aggregated.")
     except KeyError as e:
         error_msg = (f"Key Error during PV aggregation. "
-                     f"Ensure 'policy_id' (or equivalent) column from your policy data "
-                     f"is set as the index (first column) in your PV Excel file: {e}")
+                     f"Ensure the first column of your PV Excel/CSV file contains policy IDs "
+                     f"that match the indices from your Policy data: {e}")
         print(f"[{pd.Timestamp.now()}] {error_msg}")
         return (None, None, None, error_msg)
     except Exception as e:
@@ -210,7 +228,7 @@ def cluster_analysis(policy_file, cashflow_file, pv_file, num_clusters):
     try:
         common_idx = seriatim_cashflows.index.intersection(proxy_cashflows.index)
         if common_idx.empty:
-            r2 = float('nan')
+            r2 = float('nan')  # Cannot compute R2 if no common cashflow periods
             print(f"[{pd.Timestamp.now()}] Warning: No common indices for R-squared calculation.")
         else:
             r2 = r2_score(seriatim_cashflows.loc[common_idx], proxy_cashflows.loc[common_idx])
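Note: R-squared is computed only over the cashflow periods the two series share. A quick sanity check of that call with assumed values:

```python
import pandas as pd
from sklearn.metrics import r2_score

seriatim = pd.Series([100.0, 90.0, 80.0], index=["t1", "t2", "t3"])
proxy = pd.Series([98.0, 92.0, 79.0], index=["t1", "t2", "t3"])

common = seriatim.index.intersection(proxy.index)
print(r2_score(seriatim.loc[common], proxy.loc[common]))  # 0.955, close to 1.0 for a good proxy
```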
@@ -240,29 +258,28 @@ with gr.Blocks() as demo:
     to the seriatim (full portfolio) results, providing accuracy metrics and visualizations.
 
     **Instructions:**
-    1. **Upload Policy Data (Excel file):** Ensure it contains columns named exactly `IssueAge`, `PolicyTerm`, `SumAssured`, and `Duration`. The first column should ideally be a unique policy identifier that matches the indices in Cashflow and PV files.
-    2. **Upload Cashflow Data (Excel file):** The **first column** of this file must be a unique policy identifier (like `policy_id`), and it will be used as the DataFrame index. The remaining columns should be cashflow periods.
-    3. **Upload Present Value Data (Excel file):** The **first column** of this file must also be a unique policy identifier, matching the policy data's identifiers. The second column should contain the present value for each policy.
+    1. **Upload Policy Data (Excel or CSV file):** Ensure it contains columns named exactly `IssueAge`, `PolicyTerm`, `SumAssured`, and `Duration`. **Crucially, double-check for leading/trailing spaces in column names.** The first column can be a unique policy identifier (though not explicitly used for clustering, it helps with index matching).
+    2. **Upload Cashflow Data (Excel or CSV file):** The **first column** of this file must be a unique policy identifier (e.g., `policy_id`), and this column will be used as the DataFrame index. The remaining columns should represent cashflow periods (e.g., `CF_Year_1`, `CF_Year_2`).
+    3. **Upload Present Value Data (Excel or CSV file):** The **first column** of this file must also be a unique policy identifier, matching the policy data's identifiers. The second column should contain the present value for each policy.
     4. Adjust the 'Number of Model Points' using the slider.
     5. Click 'Run Clustering'.
     """)
 
     with gr.Row():
         with gr.Column():
-            policy_input = gr.File(label="1. Upload Policy Data (Excel - .xlsx/.xls)", file_types=[".xlsx", ".xls"])
-            cashflow_input = gr.File(label="2. Upload Cashflow Data (Excel - .xlsx/.xls, first column is Policy ID)", file_types=[".xlsx", ".xls"])
-            pv_input = gr.File(label="3. Upload Present Value Data (Excel - .xlsx/.xls, first column is Policy ID)", file_types=[".xlsx", ".xls"])
+            # Updated file_types to include CSV
+            policy_input = gr.File(label="1. Upload Policy Data (Excel/CSV)", file_types=[".xlsx", ".xls", ".csv"])
+            cashflow_input = gr.File(label="2. Upload Cashflow Data (Excel/CSV, first column is Policy ID)", file_types=[".xlsx", ".xls", ".csv"])
+            pv_input = gr.File(label="3. Upload Present Value Data (Excel/CSV, first column is Policy ID)", file_types=[".xlsx", ".xls", ".csv"])
             clusters_input = gr.Slider(minimum=2, maximum=100, step=1, value=10, label="4. Number of Model Points")
             run_btn = gr.Button("Run Clustering", variant="primary")
 
         with gr.Column():
-            # For displaying the CSV content directly. If you want a downloadable file, use gr.File(file_to_share=True)
             output_csv = gr.Textbox(label="Model Points CSV Output (Scroll to view)", lines=10, interactive=False)
             cashflow_img = gr.Image(label="Aggregated Cashflows Comparison", interactive=False)
             pv_img = gr.Image(label="Aggregated Present Values Comparison", interactive=False)
             metrics_box = gr.Textbox(label="Accuracy Metrics and Status", lines=4, interactive=False)
 
-    # Link the button click to the function
     run_btn.click(
         cluster_analysis,
         inputs=[policy_input, cashflow_input, pv_input, clusters_input],
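Note: for reference, a hypothetical pair of inputs matching the documented layout (all file names, IDs, and values assumed) can be generated like this:

```python
import pandas as pd

# Policy data: required attribute columns plus an identifier in the first column.
pd.DataFrame({
    "policy_id": ["P1", "P2"],
    "IssueAge": [30, 45],
    "PolicyTerm": [20, 10],
    "SumAssured": [100000, 50000],
    "Duration": [5, 2],
}).to_csv("policy_data.csv", index=False)

# Cashflow data: first column is the policy ID, remaining columns are periods.
pd.DataFrame({
    "policy_id": ["P1", "P2"],
    "CF_Year_1": [1200.0, 800.0],
    "CF_Year_2": [1150.0, 760.0],
}).to_csv("cashflow_data.csv", index=False)
```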
 