Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import pandas as pd
|
3 |
+
import numpy as np
|
4 |
+
from sklearn.cluster import KMeans
|
5 |
+
from sklearn.metrics import r2_score, pairwise_distances_argmin_min
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
import io
|
8 |
+
|
9 |
+
def run_cluster_analysis(
|
10 |
+
policy_data_file,
|
11 |
+
cashflow_data_file,
|
12 |
+
pv_data_file,
|
13 |
+
num_clusters,
|
14 |
+
clustering_variable_choice
|
15 |
+
):
|
16 |
+
"""
|
17 |
+
Performs cluster analysis for model point selection and generates comparative outputs.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
policy_data_file: Gradio File object for policy attributes data (e.g., issue age, policy term).
|
21 |
+
cashflow_data_file: Gradio File object for seriatim cashflow data.
|
22 |
+
pv_data_file: Gradio File object for seriatim present value data.
|
23 |
+
num_clusters: The desired number of representative model points (k for K-means).
|
24 |
+
clustering_variable_choice: The type of variables to use for clustering
|
25 |
+
("Net Cashflows", "Policy Attributes", "Present Values").
|
26 |
+
|
27 |
+
Returns:
|
28 |
+
A tuple containing:
|
29 |
+
- A CSV string of the selected model points with their weights.
|
30 |
+
- A BytesIO object containing a PNG image of the cashflow comparison plot.
|
31 |
+
- A BytesIO object containing a PNG image of the present value comparison plot.
|
32 |
+
- A string summarizing key accuracy metrics.
|
33 |
+
"""
|
34 |
+
|
35 |
+
# --- 1. Load Data ---
|
36 |
+
# Actuaries: Please ensure your Excel files have the correct format and column names.
|
37 |
+
# The notebook mentions 'policy data', 'cashflow data', and 'present value data'.
|
38 |
+
# For this app, we assume these are Excel files.
|
39 |
+
try:
|
40 |
+
# Load policy data; assuming policy identifiers are implicitly handled or not index.
|
41 |
+
policy_data = pd.read_excel(policy_data_file.name)
|
42 |
+
|
43 |
+
# Load cashflow data; assuming policy identifiers are in the first column or index.
|
44 |
+
# The notebook implies policies as rows and periods as columns for cashflows.
|
45 |
+
cashflow_data = pd.read_excel(cashflow_data_file.name, index_col=0)
|
46 |
+
|
47 |
+
# Load present value data; assuming policy identifiers are in the first column or index.
|
48 |
+
# The notebook implies policies as rows and PV components as columns, or a single PV column.
|
49 |
+
pv_data = pd.read_excel(pv_data_file.name, index_col=0)
|
50 |
+
|
51 |
+
except Exception as e:
|
52 |
+
return (f"Error loading files: {e}. Please ensure you upload valid Excel files "
|
53 |
+
"with appropriate data (e.g., policy IDs as index for cashflows/PVs).",
|
54 |
+
None, None, None)
|
55 |
+
|
56 |
+
# --- 2. Data Preparation for Clustering ---
|
57 |
+
X = pd.DataFrame() # Initialize an empty DataFrame for clustering variables
|
58 |
+
|
59 |
+
if clustering_variable_choice == "Policy Attributes":
|
60 |
+
# Actuaries: Adjust these column names to match your policy_data.xlsx file.
|
61 |
+
# The notebook mentions 'issue age', 'policy term', 'sum assured', and 'duration'.
|
62 |
+
required_cols = ['IssueAge', 'PolicyTerm', 'SumAssured', 'Duration']
|
63 |
+
if not all(col in policy_data.columns for col in required_cols):
|
64 |
+
return (f"Missing expected columns in Policy Data for 'Policy Attributes' clustering. "
|
65 |
+
f"Expected: {required_cols}. Please adjust your file or the code.",
|
66 |
+
None, None, None)
|
67 |
+
X = policy_data[required_cols]
|
68 |
+
elif clustering_variable_choice == "Net Cashflows":
|
69 |
+
# The notebook uses the full cashflow series for clustering.
|
70 |
+
# Ensure cashflow_data is purely numerical and represents cashflows over time.
|
71 |
+
X = cashflow_data.fillna(0) # Handle potential NaN values
|
72 |
+
elif clustering_variable_choice == "Present Values":
|
73 |
+
# Actuaries: Adjust this column name to match your pv_data.xlsx file.
|
74 |
+
# The notebook implies a main present value column (e.g., 'PV_Net_CF').
|
75 |
+
required_col = 'PV_Net_CF'
|
76 |
+
if required_col not in pv_data.columns:
|
77 |
+
return (f"Missing expected column '{required_col}' in Present Value Data for 'Present Values' clustering. "
|
78 |
+
"Please adjust your file or the code.",
|
79 |
+
None, None, None)
|
80 |
+
X = pv_data[[required_col]]
|
81 |
+
else:
|
82 |
+
return "Invalid clustering variable choice.", None, None, None
|
83 |
+
|
84 |
+
# Ensure policy_data, cashflow_data, and pv_data have a common index for merging later
|
85 |
+
# If not, you might need to merge them based on a common 'PolicyID' column
|
86 |
+
# For this example, we assume they all share a common index (e.g., policy IDs).
|
87 |
+
if not all(df.index.equals(policy_data.index) for df in [cashflow_data, pv_data]):
|
88 |
+
# If indexes are not aligned, try to align them by a common 'PolicyID' column if available.
|
89 |
+
# For simplicity in this demo, we'll assume they are aligned by index for now.
|
90 |
+
# A robust solution would involve merging or re-indexing.
|
91 |
+
pass # No action, assume alignment for now.
|
92 |
+
|
93 |
+
# Standardize data for clustering to prevent features with large values from dominating
|
94 |
+
# This is a common practice in k-means.
|
95 |
+
X_scaled = (X - X.mean()) / X.std()
|
96 |
+
X_scaled = X_scaled.fillna(0) # Replace NaNs after scaling (e.g., for columns with zero standard deviation)
|
97 |
+
|
98 |
+
# --- 3. K-means Clustering ---
|
99 |
+
try:
|
100 |
+
# n_init='auto' (default in scikit-learn 1.4+) or n_init=10 (for older versions)
|
101 |
+
# provides more robust centroid initialization.
|
102 |
+
kmeans = KMeans(n_clusters=num_clusters, random_state=42, n_init='auto')
|
103 |
+
kmeans.fit(X_scaled)
|
104 |
+
|
105 |
+
# Assign cluster labels back to the original policy data
|
106 |
+
policy_data['Cluster'] = kmeans.labels_
|
107 |
+
except Exception as e:
|
108 |
+
return f"Error during K-means clustering: {e}", None, None, None
|
109 |
+
|
110 |
+
# --- 4. Select Representative Model Points and Calculate Weights ---
|
111 |
+
# Find the policy closest to each cluster centroid to represent that cluster.
|
112 |
+
# `pairwise_distances_argmin_min` returns indices of closest points.
|
113 |
+
closest_policies_indices = pairwise_distances_argmin_min(kmeans.cluster_centers_, X_scaled)[0]
|
114 |
+
|
115 |
+
# Select the actual policy data for these representative points.
|
116 |
+
model_points = policy_data.iloc[closest_policies_indices].copy()
|
117 |
+
|
118 |
+
# Calculate weights for each model point: count of original policies in its cluster.
|
119 |
+
cluster_counts = policy_data['Cluster'].value_counts()
|
120 |
+
model_points['Weight'] = model_points['Cluster'].map(cluster_counts)
|
121 |
+
|
122 |
+
# --- 5. Aggregate Cashflows and Present Values for comparison ---
|
123 |
+
# Compare aggregated results of seriatim portfolio vs. proxy portfolio.
|
124 |
+
|
125 |
+
# Total aggregated cashflows from the original seriatim data
|
126 |
+
total_seriatim_cashflows = cashflow_data.sum(axis=0)
|
127 |
+
|
128 |
+
# Total aggregated present values from the original seriatim data
|
129 |
+
total_seriatim_pvs = pv_data.sum(axis=0)
|
130 |
+
|
131 |
+
# Calculate proxy aggregated cashflows and present values
|
132 |
+
# We multiply the cashflows/PVs of the selected model points by their calculated weights.
|
133 |
+
# Ensure the indices align for correct multiplication.
|
134 |
+
proxy_cashflows = cashflow_data.loc[model_points.index].multiply(model_points['Weight'], axis=0).sum(axis=0)
|
135 |
+
proxy_pvs = pv_data.loc[model_points.index].multiply(model_points['Weight'], axis=0).sum(axis=0)
|
136 |
+
|
137 |
+
|
138 |
+
# --- 6. Generate Outputs ---
|
139 |
+
|
140 |
+
# Prepare model points for download as a CSV file.
|
141 |
+
model_points_output = model_points.to_csv(index=False)
|
142 |
+
|
143 |
+
# --- Plotting Aggregated Cashflows ---
|
144 |
+
fig_cf, ax_cf = plt.subplots(figsize=(10, 6))
|
145 |
+
total_seriatim_cashflows.plot(ax=ax_cf, label='Seriatim Cashflows', color='blue')
|
146 |
+
proxy_cashflows.plot(ax=ax_cf, label='Proxy Cashflows', linestyle='--', color='orange')
|
147 |
+
ax_cf.set_title('Aggregated Cashflows: Seriatim vs. Proxy')
|
148 |
+
ax_cf.set_xlabel('Projection Period')
|
149 |
+
ax_cf.set_ylabel('Cashflow Amount')
|
150 |
+
ax_cf.legend()
|
151 |
+
ax_cf.grid(True)
|
152 |
+
|
153 |
+
# Save plot to a BytesIO object
|
154 |
+
buf_cf = io.BytesIO()
|
155 |
+
plt.savefig(buf_cf, format='png')
|
156 |
+
plt.close(fig_cf) # Close the figure to free up memory
|
157 |
+
buf_cf.seek(0) # Reset buffer position
|
158 |
+
img_cf = buf_cf.read()
|
159 |
+
|
160 |
+
# --- Plotting Aggregated Present Values ---
|
161 |
+
fig_pv, ax_pv = plt.subplots(figsize=(8, 5))
|
162 |
+
pv_comparison_data = pd.DataFrame({
|
163 |
+
'Seriatim PV': total_seriatim_pvs.iloc[0] if isinstance(total_seriatim_pvs, pd.Series) else total_seriatim_pvs,
|
164 |
+
'Proxy PV': proxy_pvs.iloc[0] if isinstance(proxy_pvs, pd.Series) else proxy_pvs
|
165 |
+
}, index=['Total PV']) # Use a dummy index for plotting if it's a single value
|
166 |
+
|
167 |
+
pv_comparison_data.plot(kind='bar', ax=ax_pv, color=['blue', 'orange'])
|
168 |
+
ax_pv.set_title('Aggregated Present Values: Seriatim vs. Proxy')
|
169 |
+
ax_pv.set_ylabel('Present Value')
|
170 |
+
ax_pv.tick_params(axis='x', rotation=45)
|
171 |
+
ax_pv.legend()
|
172 |
+
ax_pv.grid(axis='y')
|
173 |
+
|
174 |
+
# Save plot to a BytesIO object
|
175 |
+
buf_pv = io.BytesIO()
|
176 |
+
plt.savefig(buf_pv, format='png')
|
177 |
+
plt.close(fig_pv) # Close the figure to free up memory
|
178 |
+
buf_pv.seek(0) # Reset buffer position
|
179 |
+
img_pv = buf_pv.read()
|
180 |
+
|
181 |
+
# --- Accuracy Metrics ---
|
182 |
+
# Calculate R-squared for cashflows to measure goodness of fit.
|
183 |
+
common_periods_cf = total_seriatim_cashflows.index.intersection(proxy_cashflows.index)
|
184 |
+
r2_cf = r2_score(total_seriatim_cashflows[common_periods_cf], proxy_cashflows[common_periods_cf])
|
185 |
+
|
186 |
+
# Calculate absolute percentage error for Present Values.
|
187 |
+
seriatim_total_pv_val = total_seriatim_pvs.iloc[0] if isinstance(total_seriatim_pvs, pd.Series) else total_seriatim_pvs
|
188 |
+
proxy_total_pv_val = proxy_pvs.iloc[0] if isinstance(proxy_pvs, pd.Series) else proxy_pvs
|
189 |
+
|
190 |
+
if seriatim_total_pv_val == 0:
|
191 |
+
pv_error_percent = float('inf') # Handle division by zero
|
192 |
+
else:
|
193 |
+
pv_error_percent = abs((proxy_total_pv_val - seriatim_total_pv_val) / seriatim_total_pv_val) * 100
|
194 |
+
|
195 |
+
metrics_output = (
|
196 |
+
f"--- Accuracy Metrics ---\n"
|
197 |
+
f"R-squared (Aggregated Cashflows): {r2_cf:.4f}\n"
|
198 |
+
f"Absolute Percentage Error (Aggregated Present Value): {pv_error_percent:.4f}%\n\n"
|
199 |
+
f"Note: The acceptable error percentage for Present Value should be specified in practice (e.g., 1%).\n"
|
200 |
+
f"For better accuracy, consider trying different 'Number of Model Points' and 'Clustering Variables'."
|
201 |
+
)
|
202 |
+
|
203 |
+
return model_points_output, img_cf, img_pv, metrics_output
|
204 |
+
|
205 |
+
# Gradio Interface Setup
|
206 |
+
# Using a minimalistic theme with default black and orange colors and default font.
|
207 |
+
with gr.Blocks(theme=gr.themes.Base(primary_hue="orange", secondary_hue="black", font="default")) as demo:
|
208 |
+
gr.Markdown("# <center> Actuarial Model Point Selection using Cluster Analysis </center>")
|
209 |
+
gr.Markdown("This app helps actuaries select representative model points from a large portfolio "
|
210 |
+
"using K-means clustering.")
|
211 |
+
gr.Markdown("Upload your policy data, cashflow data, and present value data. "
|
212 |
+
"Then, configure the clustering parameters to generate representative model points and "
|
213 |
+
"analyze the accuracy of the proxy portfolio.")
|
214 |
+
|
215 |
+
with gr.Row():
|
216 |
+
with gr.Column():
|
217 |
+
gr.Markdown("### Input Data (Excel Files)")
|
218 |
+
policy_data_input = gr.File(
|
219 |
+
label="Upload Policy Data (e.g., policy_data.xlsx)",
|
220 |
+
file_types=[".xlsx", ".xls"],
|
221 |
+
type="filepath",
|
222 |
+
info="Contains policy attributes like Issue Age, Policy Term, Sum Assured, Duration."
|
223 |
+
)
|
224 |
+
cashflow_data_input = gr.File(
|
225 |
+
label="Upload Base Scenario Cashflow Data (e.g., cashflows_seriatim_10K.xlsx)",
|
226 |
+
file_types=[".xlsx", ".xls"],
|
227 |
+
type="filepath",
|
228 |
+
info="Net annual cashflows for each seriatim policy over projection periods. "
|
229 |
+
"Policies as rows, periods as columns. First column as Policy ID/Index."
|
230 |
+
)
|
231 |
+
pv_data_input = gr.File(
|
232 |
+
label="Upload Base Scenario Present Value Data (e.g., pvs_seriatim_10K.xlsx)",
|
233 |
+
file_types=[".xlsx", ".xls"],
|
234 |
+
type="filepath",
|
235 |
+
info="Present values for each seriatim policy. "
|
236 |
+
"Policies as rows, PV components as columns. First column as Policy ID/Index. "
|
237 |
+
"Expected column for total PV: 'PV_Net_CF'."
|
238 |
+
)
|
239 |
+
|
240 |
+
with gr.Column():
|
241 |
+
gr.Markdown("### Clustering Parameters")
|
242 |
+
num_clusters_input = gr.Slider(
|
243 |
+
minimum=10,
|
244 |
+
maximum=2000, # A reasonable range for 10,000 policies, can be adjusted
|
245 |
+
value=1000, # Default based on the notebook's example (1000 out of 10,000 policies)
|
246 |
+
step=10,
|
247 |
+
label="Number of Representative Model Points (k)",
|
248 |
+
info="This determines the size of the proxy portfolio. Higher values may increase accuracy but reduce efficiency."
|
249 |
+
)
|
250 |
+
clustering_variable_choice_input = gr.Dropdown(
|
251 |
+
choices=["Net Cashflows", "Policy Attributes", "Present Values"],
|
252 |
+
value="Present Values", # Notebook indicates Present Values often yield best results for PV estimation
|
253 |
+
label="Variables for Clustering",
|
254 |
+
info="The choice of variables significantly impacts results. "
|
255 |
+
"The chosen variables are more accurately estimated by the proxy portfolio."
|
256 |
+
)
|
257 |
+
process_button = gr.Button("Run Cluster Analysis")
|
258 |
+
|
259 |
+
with gr.Tab("Results"):
|
260 |
+
gr.Markdown("### Selected Model Points")
|
261 |
+
gr.Markdown("Download the CSV below to get the representative model points and their assigned weights. "
|
262 |
+
"These can be used to construct a proxy portfolio.")
|
263 |
+
model_points_output = gr.File(label="Download Selected Model Points (CSV)", file_types=[".csv"])
|
264 |
+
|
265 |
+
gr.Markdown("### Aggregated Cashflows Comparison")
|
266 |
+
gr.Markdown("This plot compares the total aggregated cashflows from your original seriatim portfolio "
|
267 |
+
"against the aggregated cashflows generated by the selected proxy model points.")
|
268 |
+
cashflow_plot_output = gr.Image(label="Seriatim vs. Proxy Aggregated Cashflows")
|
269 |
+
|
270 |
+
gr.Markdown("### Aggregated Present Values Comparison")
|
271 |
+
gr.Markdown("This plot compares the total aggregated present values of your original seriatim portfolio "
|
272 |
+
"against the aggregated present values generated by the selected proxy model points.")
|
273 |
+
pv_plot_output = gr.Image(label="Seriatim vs. Proxy Aggregated Present Values")
|
274 |
+
|
275 |
+
gr.Markdown("### Accuracy Summary")
|
276 |
+
gr.Markdown("Key metrics to assess how well the proxy portfolio represents the seriatim portfolio.")
|
277 |
+
metrics_output = gr.Textbox(label="Key Accuracy Metrics", lines=7)
|
278 |
+
|
279 |
+
|
280 |
+
process_button.click(
|
281 |
+
run_cluster_analysis,
|
282 |
+
inputs=[
|
283 |
+
policy_data_input,
|
284 |
+
cashflow_data_input,
|
285 |
+
pv_data_input,
|
286 |
+
num_clusters_input,
|
287 |
+
clustering_variable_choice_input
|
288 |
+
],
|
289 |
+
outputs=[
|
290 |
+
model_points_output,
|
291 |
+
cashflow_plot_output,
|
292 |
+
pv_plot_output,
|
293 |
+
metrics_output
|
294 |
+
]
|
295 |
+
)
|
296 |
+
|
297 |
+
demo.launch(debug=True) # Set debug=True for local testing and more verbose output
|