alidenewade committed on
Commit
7e17387
·
verified ·
1 Parent(s): e54138b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +297 -0
app.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import numpy as np
4
+ from sklearn.cluster import KMeans
5
+ from sklearn.metrics import r2_score, pairwise_distances_argmin_min
6
+ import matplotlib.pyplot as plt
7
+ import io
8
+
9
def _resolve_upload_path(upload):
    """Return the filesystem path for one uploaded file.

    Accepts either a path (``str`` / ``os.PathLike``), which is what the
    ``gr.File(type="filepath")`` inputs of this app deliver, or an object that
    exposes the path through a ``name`` attribute (what the original code
    expected).

    Raises:
        ValueError: if nothing was uploaded or no path can be determined.
    """
    import os

    if isinstance(upload, (str, os.PathLike)):
        return os.fspath(upload)
    path = getattr(upload, "name", None)
    if not path:
        raise ValueError("a required file was not uploaded")
    return path


def _error_result(message):
    """Build the failure result: no file, no plots, message in the text slot."""
    return None, None, None, message


def _plot_cashflows(seriatim_cashflows, proxy_cashflows, png_path):
    """Save the seriatim-vs-proxy aggregated cashflow line chart to *png_path*."""
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        seriatim_cashflows.plot(ax=ax, label='Seriatim Cashflows', color='blue')
        proxy_cashflows.plot(ax=ax, label='Proxy Cashflows', linestyle='--', color='orange')
        ax.set_title('Aggregated Cashflows: Seriatim vs. Proxy')
        ax.set_xlabel('Projection Period')
        ax.set_ylabel('Cashflow Amount')
        ax.legend()
        ax.grid(True)
        # Save through the figure object rather than pyplot's global
        # "current figure", which is shared between concurrent requests.
        fig.savefig(png_path, format='png')
    finally:
        plt.close(fig)  # Always release the figure, even if plotting failed
    return png_path


def _plot_present_values(seriatim_pv, proxy_pv, png_path):
    """Save the seriatim-vs-proxy total present value bar chart to *png_path*."""
    fig, ax = plt.subplots(figsize=(8, 5))
    try:
        comparison = pd.DataFrame(
            {'Seriatim PV': seriatim_pv, 'Proxy PV': proxy_pv},
            index=['Total PV'],  # Single dummy row so both totals plot as one bar group
        )
        comparison.plot(kind='bar', ax=ax, color=['blue', 'orange'])
        ax.set_title('Aggregated Present Values: Seriatim vs. Proxy')
        ax.set_ylabel('Present Value')
        ax.tick_params(axis='x', rotation=45)
        ax.legend()
        ax.grid(axis='y')
        fig.savefig(png_path, format='png')
    finally:
        plt.close(fig)
    return png_path


def run_cluster_analysis(
    policy_data_file,
    cashflow_data_file,
    pv_data_file,
    num_clusters,
    clustering_variable_choice
):
    """
    Performs cluster analysis for model point selection and generates comparative outputs.

    The three workbooks are matched to each other by row position: row i of the
    policy file, the cashflow file and the present value file must describe the
    same policy. The policy file is read without an index column, so label-based
    matching between the files is not possible.

    Args:
        policy_data_file: Uploaded policy attributes workbook, given as a path
            string or as an object exposing the path as ``.name``.
        cashflow_data_file: Uploaded seriatim cashflow workbook (policies as
            rows, projection periods as columns, first column used as index).
        pv_data_file: Uploaded seriatim present value workbook (policies as
            rows, first column used as index).
        num_clusters: The desired number of representative model points (k for K-means).
        clustering_variable_choice: The type of variables to use for clustering
            ("Net Cashflows", "Policy Attributes", "Present Values").

    Returns:
        A 4-tuple matching the File / Image / Image / Textbox output components:
            - Path of a CSV file with the selected model points and their weights.
            - Path of a PNG image of the cashflow comparison plot.
            - Path of a PNG image of the present value comparison plot.
            - A string summarizing key accuracy metrics.
        On failure the first three items are None and the fourth carries the
        error message, so the message is shown as text instead of being
        interpreted as a file path.
    """
    # Function-scope imports: only this function needs them.
    import os
    import tempfile

    # --- 1. Load Data ---
    # Actuaries: Please ensure your Excel files have the correct format and column names.
    try:
        policy_data = pd.read_excel(_resolve_upload_path(policy_data_file))
        cashflow_data = pd.read_excel(_resolve_upload_path(cashflow_data_file), index_col=0)
        pv_data = pd.read_excel(_resolve_upload_path(pv_data_file), index_col=0)
    except Exception as e:
        # Broad on purpose: this is the UI boundary and read_excel can fail in
        # many ways (missing file, bad archive, missing engine, ...).
        return _error_result(
            f"Error loading files: {e}. Please ensure you upload valid Excel files "
            "with appropriate data (e.g., policy IDs as index for cashflows/PVs)."
        )

    n_policies = len(policy_data)
    if n_policies == 0 or len(cashflow_data) != n_policies or len(pv_data) != n_policies:
        return _error_result(
            "Policy, cashflow and present value files must contain the same, non-zero "
            f"number of policies (got {n_policies}, {len(cashflow_data)} and {len(pv_data)} rows)."
        )
    if cashflow_data.shape[1] == 0 or pv_data.shape[1] == 0:
        return _error_result(
            "Cashflow and present value files need at least one data column besides the policy ID."
        )

    # --- 2. Data Preparation for Clustering ---
    if clustering_variable_choice == "Policy Attributes":
        # Actuaries: Adjust these column names to match your policy_data.xlsx file.
        required_cols = ['IssueAge', 'PolicyTerm', 'SumAssured', 'Duration']
        if not all(col in policy_data.columns for col in required_cols):
            return _error_result(
                f"Missing expected columns in Policy Data for 'Policy Attributes' clustering. "
                f"Expected: {required_cols}. Please adjust your file or the code."
            )
        X = policy_data[required_cols]
    elif clustering_variable_choice == "Net Cashflows":
        # The full cashflow series is used; missing periods count as zero cashflow.
        X = cashflow_data.fillna(0)
    elif clustering_variable_choice == "Present Values":
        # Actuaries: Adjust this column name to match your pv_data.xlsx file.
        required_col = 'PV_Net_CF'
        if required_col not in pv_data.columns:
            return _error_result(
                f"Missing expected column '{required_col}' in Present Value Data for 'Present Values' clustering. "
                "Please adjust your file or the code."
            )
        X = pv_data[[required_col]]
    else:
        return _error_result("Invalid clustering variable choice.")

    # The UI slider may deliver a float; KMeans requires an integer k that does
    # not exceed the number of samples.
    try:
        k = int(num_clusters)
    except (TypeError, ValueError):
        return _error_result("Number of model points must be a whole number.")
    if not 1 <= k <= n_policies:
        return _error_result(
            f"Number of model points must be between 1 and the number of policies ({n_policies})."
        )

    # --- 3. K-means Clustering ---
    try:
        # Standardize so that features with large magnitudes do not dominate the
        # distance metric. Zero-variance columns produce NaN and are set to 0.
        X_scaled = ((X - X.mean()) / X.std()).fillna(0)
        # A plain float array avoids any dependence on DataFrame column labels.
        features = X_scaled.to_numpy(dtype=float)

        # n_init='auto' requires a reasonably recent scikit-learn.
        kmeans = KMeans(n_clusters=k, random_state=42, n_init='auto')
        kmeans.fit(features)
    except Exception as e:
        return _error_result(f"Error during K-means clustering: {e}")

    labels = kmeans.labels_
    policy_data['Cluster'] = labels

    # --- 4. Select Representative Model Points and Calculate Weights ---
    # Row position (not index label) of the policy closest to each centroid,
    # in centroid order.
    closest_positions, _ = pairwise_distances_argmin_min(kmeans.cluster_centers_, features)

    model_points = policy_data.iloc[closest_positions].copy()

    # Weight of a model point = number of policies in the cluster it represents.
    # Counting per centroid (rather than via the representative's own label)
    # guarantees that the weights add up to the portfolio size.
    cluster_sizes = (
        pd.Series(labels).value_counts().reindex(range(k), fill_value=0).to_numpy()
    )
    model_points['Cluster'] = np.arange(k)
    model_points['Weight'] = cluster_sizes

    # --- 5. Aggregate Cashflows and Present Values for comparison ---
    total_seriatim_cashflows = cashflow_data.sum(axis=0)

    # Positional selection and a plain weight array: the policy table has a
    # RangeIndex while cashflows/PVs are indexed by policy ID, so label-based
    # .loc / index-aligned multiplication would pick wrong rows or fail.
    proxy_cashflows = (
        cashflow_data.iloc[closest_positions].multiply(cluster_sizes, axis=0).sum(axis=0)
    )

    # Compare the documented total-PV column when present; otherwise fall back
    # to the first PV column (the original behaviour).
    pv_column = 'PV_Net_CF' if 'PV_Net_CF' in pv_data.columns else pv_data.columns[0]
    seriatim_total_pv_val = pv_data[pv_column].sum()
    proxy_total_pv_val = (pv_data[pv_column].iloc[closest_positions].to_numpy() * cluster_sizes).sum()

    # --- 6. Generate Outputs ---
    # File and Image output components take file paths, so results are written
    # to a per-run temporary directory.
    # NOTE(review): these directories are not cleaned up here; confirm whether
    # the hosting environment purges its temp folder.
    out_dir = tempfile.mkdtemp(prefix="model_points_")

    csv_path = os.path.join(out_dir, "model_points.csv")
    model_points.to_csv(csv_path, index=False)

    img_cf = _plot_cashflows(
        total_seriatim_cashflows, proxy_cashflows, os.path.join(out_dir, "cashflows.png")
    )
    img_pv = _plot_present_values(
        seriatim_total_pv_val, proxy_total_pv_val, os.path.join(out_dir, "present_values.png")
    )

    # --- Accuracy Metrics ---
    # Both series are column sums of the same DataFrame, so they share the same
    # period order and can be compared element-wise.
    r2_cf = r2_score(total_seriatim_cashflows.to_numpy(), proxy_cashflows.to_numpy())

    if seriatim_total_pv_val == 0:
        pv_error_percent = float('inf')  # Relative error is undefined for a zero base
    else:
        pv_error_percent = abs((proxy_total_pv_val - seriatim_total_pv_val) / seriatim_total_pv_val) * 100

    metrics_output = (
        f"--- Accuracy Metrics ---\n"
        f"R-squared (Aggregated Cashflows): {r2_cf:.4f}\n"
        f"Absolute Percentage Error (Aggregated Present Value): {pv_error_percent:.4f}%\n\n"
        f"Note: The acceptable error percentage for Present Value should be specified in practice (e.g., 1%).\n"
        f"For better accuracy, consider trying different 'Number of Model Points' and 'Clustering Variables'."
    )

    return csv_path, img_cf, img_pv, metrics_output
204
+
205
# Gradio Interface Setup
# Minimal base theme with an orange accent. Hue arguments must name one of
# Gradio's built-in palettes; there is no "black" palette, so the dark-grey
# "neutral" palette is used for the secondary hue. The font is left at the
# theme default.
with gr.Blocks(theme=gr.themes.Base(primary_hue="orange", secondary_hue="neutral")) as demo:
    gr.Markdown("# <center> Actuarial Model Point Selection using Cluster Analysis </center>")
    gr.Markdown("This app helps actuaries select representative model points from a large portfolio "
                "using K-means clustering.")
    gr.Markdown("Upload your policy data, cashflow data, and present value data. "
                "Then, configure the clustering parameters to generate representative model points and "
                "analyze the accuracy of the proxy portfolio.")

    with gr.Row():
        with gr.Column():
            gr.Markdown("### Input Data (Excel Files)")
            # gr.File has no `info` argument, so the per-file guidance is shown
            # as Markdown underneath each upload widget instead.
            policy_data_input = gr.File(
                label="Upload Policy Data (e.g., policy_data.xlsx)",
                file_types=[".xlsx", ".xls"],
                type="filepath",
            )
            gr.Markdown("Contains policy attributes like Issue Age, Policy Term, Sum Assured, Duration.")
            cashflow_data_input = gr.File(
                label="Upload Base Scenario Cashflow Data (e.g., cashflows_seriatim_10K.xlsx)",
                file_types=[".xlsx", ".xls"],
                type="filepath",
            )
            gr.Markdown("Net annual cashflows for each seriatim policy over projection periods. "
                        "Policies as rows, periods as columns. First column as Policy ID/Index.")
            pv_data_input = gr.File(
                label="Upload Base Scenario Present Value Data (e.g., pvs_seriatim_10K.xlsx)",
                file_types=[".xlsx", ".xls"],
                type="filepath",
            )
            gr.Markdown("Present values for each seriatim policy. "
                        "Policies as rows, PV components as columns. First column as Policy ID/Index. "
                        "Expected column for total PV: 'PV_Net_CF'.")

        with gr.Column():
            gr.Markdown("### Clustering Parameters")
            num_clusters_input = gr.Slider(
                minimum=10,
                maximum=2000,  # A reasonable range for 10,000 policies, can be adjusted
                value=1000,  # Default based on the notebook's example (1000 out of 10,000 policies)
                step=10,
                label="Number of Representative Model Points (k)",
                info="This determines the size of the proxy portfolio. Higher values may increase accuracy but reduce efficiency."
            )
            clustering_variable_choice_input = gr.Dropdown(
                choices=["Net Cashflows", "Policy Attributes", "Present Values"],
                value="Present Values",  # Notebook indicates Present Values often yield best results for PV estimation
                label="Variables for Clustering",
                info="The choice of variables significantly impacts results. "
                     "The chosen variables are more accurately estimated by the proxy portfolio."
            )
            process_button = gr.Button("Run Cluster Analysis")

    with gr.Tab("Results"):
        gr.Markdown("### Selected Model Points")
        gr.Markdown("Download the CSV below to get the representative model points and their assigned weights. "
                    "These can be used to construct a proxy portfolio.")
        model_points_output = gr.File(label="Download Selected Model Points (CSV)", file_types=[".csv"])

        gr.Markdown("### Aggregated Cashflows Comparison")
        gr.Markdown("This plot compares the total aggregated cashflows from your original seriatim portfolio "
                    "against the aggregated cashflows generated by the selected proxy model points.")
        cashflow_plot_output = gr.Image(label="Seriatim vs. Proxy Aggregated Cashflows")

        gr.Markdown("### Aggregated Present Values Comparison")
        gr.Markdown("This plot compares the total aggregated present values of your original seriatim portfolio "
                    "against the aggregated present values generated by the selected proxy model points.")
        pv_plot_output = gr.Image(label="Seriatim vs. Proxy Aggregated Present Values")

        gr.Markdown("### Accuracy Summary")
        gr.Markdown("Key metrics to assess how well the proxy portfolio represents the seriatim portfolio.")
        metrics_output = gr.Textbox(label="Key Accuracy Metrics", lines=7)

    # Wire the button to the analysis; the order of `outputs` must match the
    # order of the values returned by run_cluster_analysis.
    process_button.click(
        run_cluster_analysis,
        inputs=[
            policy_data_input,
            cashflow_data_input,
            pv_data_input,
            num_clusters_input,
            clustering_variable_choice_input
        ],
        outputs=[
            model_points_output,
            cashflow_plot_output,
            pv_plot_output,
            metrics_output
        ]
    )

# Start the server only when executed as a script, so importing this module
# (e.g. from a test) does not launch it.
if __name__ == "__main__":
    demo.launch(debug=True)  # debug=True for local testing and more verbose output