Update app.py
app.py CHANGED
@@ -5,15 +5,17 @@ from sklearn.cluster import KMeans
 from sklearn.metrics import r2_score, pairwise_distances_argmin_min
 import matplotlib.pyplot as plt
 import io
+import os  # For checking file extensions
 
 def cluster_analysis(policy_file, cashflow_file, pv_file, num_clusters):
     """
     Performs cluster analysis for actuarial model point selection.
+    Accepts both Excel and CSV files.
 
     Args:
-        policy_file: Gradio File object for policy data
-        cashflow_file: Gradio File object for cashflow data
-        pv_file: Gradio File object for present value data
+        policy_file: Gradio File object for policy data.
+        cashflow_file: Gradio File object for cashflow data.
+        pv_file: Gradio File object for present value data.
         num_clusters: Number of clusters (model points) to generate.
 
     Returns:
@@ -35,19 +37,30 @@ def cluster_analysis(policy_file, cashflow_file, pv_file, num_clusters):
     print(f"PV file received: {pv_file.name if pv_file else 'None'}")
     print("="*50 + "\n")
 
+    # Helper function to read files based on extension
+    def read_data_file(file_obj, index_col=None):
+        if file_obj is None:
+            raise ValueError("File object is None.")
+
+        file_path = file_obj.name
+        file_extension = os.path.splitext(file_path)[1].lower()
+
+        if file_extension in ['.xlsx', '.xls']:
+            print(f"Attempting to read Excel file: {file_path}")
+            return pd.read_excel(file_path, index_col=index_col)
+        elif file_extension == '.csv':
+            print(f"Attempting to read CSV file: {file_path}")
+            # Consider adding 'sep' argument if CSV delimiter is not comma, e.g., sep=';'
+            return pd.read_csv(file_path, index_col=index_col)
+        else:
+            raise ValueError(f"Unsupported file type: {file_extension}. Please upload .xlsx, .xls, or .csv files.")
+
     # 1. Basic checks and file reading
     try:
-
-        missing_files = []
-        if policy_file is None: missing_files.append("Policy Data")
-        if cashflow_file is None: missing_files.append("Cashflow Data")
-        if pv_file is None: missing_files.append("Present Value Data")
-        raise ValueError(f"Missing required input file(s): {', '.join(missing_files)}. Please upload all files.")
-
-        policy_df = pd.read_excel(policy_file.name)
+        policy_df = read_data_file(policy_file)
         # index_col=0 is crucial. Ensure the first column contains unique policy identifiers.
-        cashflow_df = pd.read_excel(cashflow_file.name, index_col=0)
-        pv_df = pd.read_excel(pv_file.name, index_col=0)
+        cashflow_df = read_data_file(cashflow_file, index_col=0)
+        pv_df = read_data_file(pv_file, index_col=0)
         print(f"[{pd.Timestamp.now()}] Files read successfully.")
         print(f"Policy data shape: {policy_df.shape}, Columns: {policy_df.columns.tolist()}")
         print(f"Cashflow data shape: {cashflow_df.shape}, Index type: {cashflow_df.index.dtype}")
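Note: the new `read_data_file` helper only touches the `.name` attribute of the Gradio `File` object, so it can be exercised locally without launching the app. A minimal sketch, assuming a hypothetical local `policies.csv`:

import types
import pandas as pd

# Stand-in for a Gradio File object: Gradio exposes the uploaded temp-file
# path via `.name`, which is the only attribute read_data_file uses.
dummy = types.SimpleNamespace(name="policies.csv")  # hypothetical local file

# Equivalent of read_data_file(dummy, index_col=0): the ".csv" extension
# routes to pd.read_csv, with column 0 used as the policy-ID index.
df = pd.read_csv(dummy.name, index_col=0)
print(df.shape)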
@@ -60,10 +73,13 @@ def cluster_analysis(policy_file, cashflow_file, pv_file, num_clusters):
 
     # 2. Validate Policy Data Columns
     required_cols = ['IssueAge', 'PolicyTerm', 'SumAssured', 'Duration']
+    # Strip whitespace from column names for robust matching
+    policy_df.columns = policy_df.columns.str.strip()
+
     if not all(col in policy_df.columns for col in required_cols):
         found_cols = policy_df.columns.tolist()
         error_msg = (f"Policy data missing required columns. Expected: {required_cols}. "
-                     f"Found: {found_cols}. Please check your policy data columns.")
+                     f"Found: {found_cols}. Please check your policy data column headers for typos or extra spaces.")
         print(f"[{pd.Timestamp.now()}] {error_msg}")
         return (None, None, None, error_msg)
 
@@ -71,7 +87,8 @@ def cluster_analysis(policy_file, cashflow_file, pv_file, num_clusters):
     try:
         X = policy_df[required_cols].fillna(0)
         # Scale data, handle cases where std is 0 (e.g., all values are the same for a feature)
-        X_scaled = (X - X.mean()) / X.std()
+        # Add a small epsilon to avoid division by zero if all values are identical
+        X_scaled = X.apply(lambda x: (x - x.mean()) / (x.std() if x.std() != 0 else 1e-9), axis=0)
         print(f"[{pd.Timestamp.now()}] Policy attributes scaled.")
     except Exception as e:
         error_msg = f"Error preparing data for clustering: {e}"
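Note: the replacement scaling line is a per-column z-score with a guard for constant columns. Since pandas' `std()` is the sample standard deviation (ddof=1), results differ slightly from sklearn's `StandardScaler` (which uses ddof=0). A toy check:

import pandas as pd

# One informative feature and one constant feature (std == 0).
X = pd.DataFrame({"IssueAge": [30, 40, 50], "PolicyTerm": [20, 20, 20]})

# Same rule as the new line in app.py: z-score per column, substituting a
# tiny divisor when a column is constant so nothing divides by zero.
X_scaled = X.apply(lambda x: (x - x.mean()) / (x.std() if x.std() != 0 else 1e-9), axis=0)
print(X_scaled)  # IssueAge -> [-1, 0, 1]; PolicyTerm -> [0, 0, 0]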
@@ -91,7 +108,8 @@ def cluster_analysis(policy_file, cashflow_file, pv_file, num_clusters):
             print(f"[{pd.Timestamp.now()}] Warning: Number of clusters ({original_num_clusters}) "
                   f"exceeded number of samples ({n_samples}). Reduced to {num_clusters}.")
 
-        kmeans = KMeans(n_clusters=num_clusters, random_state=42)
+        # Use 'auto' for n_init for newer scikit-learn versions
+        kmeans = KMeans(n_clusters=num_clusters, random_state=42, n_init='auto')
        kmeans.fit(X_scaled)
         policy_df['Cluster'] = kmeans.labels_
         print(f"[{pd.Timestamp.now()}] Clustering successful with {num_clusters} clusters.")
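Note: `pairwise_distances_argmin_min` is imported at the top of app.py, which suggests the model points are the real policies nearest each centroid (the selection code itself falls outside these hunks). A sketch of that pattern on random data, assuming scikit-learn >= 1.2 for `n_init='auto'`:

import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import pairwise_distances_argmin_min

rng = np.random.default_rng(0)
X_scaled = rng.normal(size=(200, 4))  # stand-in for the scaled policy attributes

kmeans = KMeans(n_clusters=10, random_state=42, n_init='auto').fit(X_scaled)

# For each centroid, take the nearest actual row; those rows become the
# model points, so every model point is a genuine policy record.
closest_idx, _ = pairwise_distances_argmin_min(kmeans.cluster_centers_, X_scaled)
print(closest_idx)  # 10 row positions into X_scaled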
@@ -109,7 +127,7 @@ def cluster_analysis(policy_file, cashflow_file, pv_file, num_clusters):
         counts = policy_df['Cluster'].value_counts()
         model_points['Weight'] = model_points['Cluster'].map(counts)
         print(f"[{pd.Timestamp.now()}] Model points selected and weights calculated. Model points shape: {model_points.shape}")
-        print(f"Model points indices: {model_points.index.tolist()}")
+        print(f"Model points indices (first 5): {model_points.index.tolist()[:5]}...")
     except Exception as e:
         error_msg = f"Error selecting model points or calculating weights: {e}"
         print(f"[{pd.Timestamp.now()}] {error_msg}")
@@ -126,15 +144,15 @@ def cluster_analysis(policy_file, cashflow_file, pv_file, num_clusters):
         # Check if all model_points indices exist in cashflow_df and pv_df
         missing_cf_indices = [idx for idx in model_points.index if idx not in cashflow_df.index]
         if missing_cf_indices:
-            raise KeyError(f"Cashflow data is missing entries for model point indices: {missing_cf_indices[:5]}...")
+            raise KeyError(f"Cashflow data is missing entries for model point indices: {missing_cf_indices[:5]}... Please check Cashflow data's first column (index).")
 
         proxy_cashflows = cashflow_df.loc[model_points.index].multiply(model_points['Weight'], axis=0).sum()
         seriatim_cashflows = cashflow_df.sum()
         print(f"[{pd.Timestamp.now()}] Cashflows aggregated.")
     except KeyError as e:
         error_msg = (f"Key Error during cashflow aggregation. "
-                     f"Ensure the first column of your Cashflow Excel file contains policy IDs "
-                     f"that match the indices from your Policy data: {e}")
+                     f"Ensure the first column of your Cashflow Excel/CSV file contains policy IDs "
+                     f"that match the indices from your Policy data: {e}")
         print(f"[{pd.Timestamp.now()}] {error_msg}")
         return (None, None, None, error_msg)
     except Exception as e:
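Note: the aggregation line is easy to misread; each model point's cashflow row is scaled by its cluster weight (the number of policies it represents), and the weighted rows are then summed into one proxy cashflow vector. Toy numbers:

import pandas as pd

# 4 policies x 2 periods, indexed by policy ID.
cashflow_df = pd.DataFrame(
    {"CF_Year_1": [10.0, 12.0, 9.0, 11.0], "CF_Year_2": [8.0, 9.0, 7.0, 8.5]},
    index=["P1", "P2", "P3", "P4"],
)
model_idx = pd.Index(["P1", "P3"])            # selected model points
weights = pd.Series([3, 1], index=model_idx)  # cluster sizes they represent

# Weighted rows summed per period: proxy CF_Year_1 = 10*3 + 9*1 = 39.
proxy = cashflow_df.loc[model_idx].multiply(weights, axis=0).sum()
seriatim = cashflow_df.sum()  # full-portfolio benchmark: CF_Year_1 = 42
print(proxy, seriatim, sep="\n")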
@@ -169,7 +187,7 @@ def cluster_analysis(policy_file, cashflow_file, pv_file, num_clusters):
     try:
         missing_pv_indices = [idx for idx in model_points.index if idx not in pv_df.index]
         if missing_pv_indices:
-            raise KeyError(f"PV data is missing entries for model point indices: {missing_pv_indices[:5]}...")
+            raise KeyError(f"PV data is missing entries for model point indices: {missing_pv_indices[:5]}... Please check PV data's first column (index).")
 
         # Assuming PV data has only one column or the relevant column is the first one
         proxy_pv = pv_df.loc[model_points.index].multiply(model_points['Weight'], axis=0).sum().values[0]
@@ -177,8 +195,8 @@ def cluster_analysis(policy_file, cashflow_file, pv_file, num_clusters):
         print(f"[{pd.Timestamp.now()}] Present Values aggregated.")
     except KeyError as e:
         error_msg = (f"Key Error during PV aggregation. "
-                     f"Ensure the first column of your PV Excel file contains policy IDs "
-                     f"that match the indices from your Policy data: {e}")
+                     f"Ensure the first column of your PV Excel/CSV file contains policy IDs "
+                     f"that match the indices from your Policy data: {e}")
         print(f"[{pd.Timestamp.now()}] {error_msg}")
         return (None, None, None, error_msg)
     except Exception as e:
@@ -210,7 +228,7 @@ def cluster_analysis(policy_file, cashflow_file, pv_file, num_clusters):
     try:
         common_idx = seriatim_cashflows.index.intersection(proxy_cashflows.index)
         if common_idx.empty:
-            r2 = float('nan')
+            r2 = float('nan')  # Cannot compute R2 if no common cashflow periods
             print(f"[{pd.Timestamp.now()}] Warning: No common indices for R-squared calculation.")
         else:
             r2 = r2_score(seriatim_cashflows.loc[common_idx], proxy_cashflows.loc[common_idx])
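Note: intersecting the indices before calling `r2_score` is what keeps the metric well-defined when the proxy and seriatim series cover different periods. A small check:

import pandas as pd
from sklearn.metrics import r2_score

seriatim = pd.Series([42.0, 32.5, 20.0], index=["Y1", "Y2", "Y3"])
proxy = pd.Series([39.0, 31.0], index=["Y1", "Y2"])  # one period missing

# Align on shared periods first, as app.py does, so r2_score compares
# equal-length vectors instead of raising a shape error.
common = seriatim.index.intersection(proxy.index)
r2 = r2_score(seriatim.loc[common], proxy.loc[common]) if not common.empty else float("nan")
print(f"R-squared over {len(common)} common periods: {r2:.3f}")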
@@ -240,29 +258,28 @@ with gr.Blocks() as demo:
     to the seriatim (full portfolio) results, providing accuracy metrics and visualizations.
 
     **Instructions:**
-    1. **Upload Policy Data (Excel file):** Ensure it contains columns named exactly `IssueAge`, `PolicyTerm`, `SumAssured`, and `Duration`. The first column can be a unique policy identifier (though not explicitly used for clustering, it helps with index matching).
-    2. **Upload Cashflow Data (Excel file):** The **first column** of this file must be a unique policy identifier (e.g., `policy_id`), and this column will be used as the DataFrame index. The remaining columns should represent cashflow periods (e.g., `CF_Year_1`, `CF_Year_2`).
-    3. **Upload Present Value Data (Excel file):** The **first column** of this file must also be a unique policy identifier, matching the policy data's identifiers. The second column should contain the present value for each policy.
+    1. **Upload Policy Data (Excel or CSV file):** Ensure it contains columns named exactly `IssueAge`, `PolicyTerm`, `SumAssured`, and `Duration`. **Crucially, double-check for leading/trailing spaces in column names.** The first column can be a unique policy identifier (though not explicitly used for clustering, it helps with index matching).
+    2. **Upload Cashflow Data (Excel or CSV file):** The **first column** of this file must be a unique policy identifier (e.g., `policy_id`), and this column will be used as the DataFrame index. The remaining columns should represent cashflow periods (e.g., `CF_Year_1`, `CF_Year_2`).
+    3. **Upload Present Value Data (Excel or CSV file):** The **first column** of this file must also be a unique policy identifier, matching the policy data's identifiers. The second column should contain the present value for each policy.
     4. Adjust the 'Number of Model Points' using the slider.
     5. Click 'Run Clustering'.
     """)
 
     with gr.Row():
         with gr.Column():
-            policy_input = gr.File(label="1. Upload Policy Data (Excel)", file_types=[".xlsx", ".xls"])
-            cashflow_input = gr.File(label="2. Upload Cashflow Data (Excel, first column is Policy ID)", file_types=[".xlsx", ".xls"])
-            pv_input = gr.File(label="3. Upload Present Value Data (Excel, first column is Policy ID)", file_types=[".xlsx", ".xls"])
+            # Updated file_types to include CSV
+            policy_input = gr.File(label="1. Upload Policy Data (Excel/CSV)", file_types=[".xlsx", ".xls", ".csv"])
+            cashflow_input = gr.File(label="2. Upload Cashflow Data (Excel/CSV, first column is Policy ID)", file_types=[".xlsx", ".xls", ".csv"])
+            pv_input = gr.File(label="3. Upload Present Value Data (Excel/CSV, first column is Policy ID)", file_types=[".xlsx", ".xls", ".csv"])
             clusters_input = gr.Slider(minimum=2, maximum=100, step=1, value=10, label="4. Number of Model Points")
             run_btn = gr.Button("Run Clustering", variant="primary")
 
         with gr.Column():
-            # For displaying the CSV content directly. If you want a downloadable file, use gr.File(file_to_share=True)
             output_csv = gr.Textbox(label="Model Points CSV Output (Scroll to view)", lines=10, interactive=False)
             cashflow_img = gr.Image(label="Aggregated Cashflows Comparison", interactive=False)
             pv_img = gr.Image(label="Aggregated Present Values Comparison", interactive=False)
             metrics_box = gr.Textbox(label="Accuracy Metrics and Status", lines=4, interactive=False)
 
-    # Link the button click to the function
     run_btn.click(
         cluster_analysis,
         inputs=[policy_input, cashflow_input, pv_input, clusters_input],
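Note: the diff is cut off after the `inputs=` line of `run_btn.click`. A hedged sketch of the likely remainder; the `outputs` order below is an assumption inferred from the four output components defined above and the function's four-element return tuples, not confirmed by the diff:

# Hypothetical completion -- not shown in this commit's diff.
run_btn.click(
    cluster_analysis,
    inputs=[policy_input, cashflow_input, pv_input, clusters_input],
    outputs=[output_csv, cashflow_img, pv_img, metrics_box],  # assumed order
)

demo.launch()  # standard Gradio entry point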