Update app.py
app.py CHANGED
@@ -2,7 +2,7 @@ import gradio as gr
 import numpy as np
 import pandas as pd
 from sklearn.cluster import KMeans
-from sklearn.metrics import pairwise_distances_argmin_min
+from sklearn.metrics import pairwise_distances_argmin_min  # r2_score is not used in the provided snippet.
 import matplotlib.pyplot as plt
 import matplotlib.cm
 import io

@@ -22,16 +22,41 @@ EXAMPLE_FILES = {
 }

 class Clusters:
-    def __init__(self,
-
-
-        rep_ids = pd.Series(data=(closest+1))  # 0-based to 1-based indexes
-        rep_ids.name = 'policy_id'
-        rep_ids.index.name = 'cluster_id'
-        self.rep_ids = rep_ids
-
-
+    def __init__(self, loc_vars_df):  # Expecting a pandas DataFrame
+        # "Quantize" by converting the input DataFrame to float32 for KMeans.
+        # This reduces precision, potentially speeding up calculations and lowering memory.
+        # Results might have minor numerical differences compared to float64.
+        # Ensure data is a C-contiguous NumPy array.
+        if loc_vars_df.empty:
+            # Handle the empty-DataFrame case to avoid errors with .values or astype.
+            # KMeans would fail anyway, but this prevents issues before that.
+            loc_vars_np_float32 = np.array([], dtype=np.float32).reshape(0, loc_vars_df.shape[1] if loc_vars_df.shape[1] > 0 else 0)
+        else:
+            loc_vars_np_float32 = np.ascontiguousarray(loc_vars_df.astype(np.float32).values)
+
+        # Initialize KMeans with algorithm="elkan" for a potential speedup
+        # and fit on the float32 data.
+        self.kmeans = KMeans(
+            n_clusters=1000,
+            random_state=0,
+            n_init=10,
+            algorithm="elkan"  # Added for speed optimization
+        ).fit(loc_vars_np_float32)
+
+        # cluster_centers_ will be float32 if fitted on float32 data.
+        # Pass the same float32 NumPy array for distance calculations.
+        closest, _ = pairwise_distances_argmin_min(
+            self.kmeans.cluster_centers_,
+            loc_vars_np_float32
+        )
+
+        self.rep_ids = pd.Series(data=(closest + 1))  # 0-based to 1-based indexes
+        self.rep_ids.name = 'policy_id'
+        self.rep_ids.index.name = 'cluster_id'
+
+        # policy_count is based on the number of items in the input data.
+        # Use loc_vars_np_float32.shape[0], which is the number of rows.
+        self.policy_count = self.agg_by_cluster(pd.DataFrame({'policy_count': [1] * loc_vars_np_float32.shape[0]}))['policy_count']

     def agg_by_cluster(self, df, agg=None):
         """Aggregate columns by cluster"""

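The core idea of this constructor is model point compression: fit k-means on the calibration variables, then pick, for each cluster, the actual policy closest to the centroid as its representative. A minimal self-contained sketch of that selection step (synthetic data and 3 clusters instead of the app's 1000; `rng`, `X`, and `km` are illustrative names, not from the app):

    import numpy as np
    import pandas as pd
    from sklearn.cluster import KMeans
    from sklearn.metrics import pairwise_distances_argmin_min

    rng = np.random.default_rng(0)
    X = rng.normal(size=(12, 4)).astype(np.float32)   # 12 policies, 4 calibration variables
    km = KMeans(n_clusters=3, random_state=0, n_init=10).fit(X)

    # For each centroid, index of the nearest actual policy (row of X).
    closest, _ = pairwise_distances_argmin_min(km.cluster_centers_, X)
    rep_ids = pd.Series(closest + 1, name='policy_id')  # 1-based, as in the class above
    rep_ids.index.name = 'cluster_id'
    print(rep_ids)
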
@@ -43,21 +68,46 @@ class Clusters:

     def extract_reps(self, df):
         """Extract the rows of representative policies"""
-
-
-
+        # Ensure policy_id in df is of the same type as self.rep_ids if it's not already the index.
+        # Typically, df here will have 'policy_id' as its index, as in the original data.
+        # If df's index is not 'policy_id', ensure a 'policy_id' column exists with a compatible type.
+        current_df_index_name = df.index.name
+        # If 'policy_id' is not the index, reset it. Otherwise, use the index.
+        if 'policy_id' not in df.columns and df.index.name != 'policy_id':
+            # This case should ideally not happen if inputs are consistent.
+            # Force the index to be named 'policy_id' if it is the policy identifier.
+            df_indexed = df.copy()
+            if df_indexed.index.name is None:  # Or some other logic to identify the policy_id column
+                gr.Warning("DataFrame passed to extract_reps has no index name, assuming index is policy_id.")
+                df_indexed.index.name = 'policy_id'
+
+            temp = pd.merge(self.rep_ids, df_indexed.reset_index(), how='left', on='policy_id')
+
+        elif 'policy_id' in df.columns and df.index.name == 'policy_id' and df.index.name in df.columns:  # policy_id is both index and a column
+            temp = pd.merge(self.rep_ids, df, how='left', on='policy_id')  # Merge on the column if available
+
+        elif df.index.name == 'policy_id':
+            temp = pd.merge(self.rep_ids, df.reset_index(), how='left', on='policy_id')
+
+        else:  # 'policy_id' is a column, not the index
+            temp = pd.merge(self.rep_ids, df.reset_index(drop=df.index.name is None), how='left', on='policy_id')
+
+        temp.index.name = 'cluster_id'  # The merge result's index is not cluster_id by default
+        temp = temp.set_index(self.rep_ids.index)  # Set the index to cluster_id from self.rep_ids
+        return temp.drop('policy_id', axis=1, errors='ignore')

     def extract_and_scale_reps(self, df, agg=None):
         """Extract and scale the rows of representative policies"""
+        extracted_df = self.extract_reps(df)
         if agg:
-            cols =
+            cols = extracted_df.columns  # Use columns from extracted_df
             mult = pd.DataFrame({c: (self.policy_count if (c not in agg or agg[c] == 'sum') else 1) for c in cols})
-
-            extracted_df = self.extract_reps(df)
-            mult.index = extracted_df.index
+            mult.index = extracted_df.index  # Align index
             return extracted_df.mul(mult)
         else:
-            return
+            return extracted_df.mul(self.policy_count, axis=0)

     def compare(self, df, agg=None):
         """Returns a multi-indexed Dataframe comparing actual and estimate"""

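With `rep_ids` in hand, `extract_reps` pulls the representatives' rows, and `extract_and_scale_reps` multiplies each row by its cluster's policy count, so summing the scaled representatives approximates the full-portfolio total. A small illustration with hypothetical numbers (two clusters, two cashflow periods):

    import pandas as pd

    # Hypothetical representative cashflow rows and cluster sizes.
    reps = pd.DataFrame({'cf_t0': [10.0, 20.0], 'cf_t1': [5.0, 8.0]},
                        index=pd.Index([1, 2], name='cluster_id'))
    policy_count = pd.Series([100, 50], index=reps.index)

    estimate_total = reps.mul(policy_count, axis=0).sum()
    print(estimate_total)  # cf_t0: 10*100 + 20*50 = 2000, cf_t1: 5*100 + 8*50 = 900
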
@@ -68,7 +118,6 @@ class Clusters:
     def compare_total(self, df, agg=None):
         """Aggregate df by columns"""
         if agg:
-            # Calculate actual values using specified aggregation
             actual_values = {}
             for col in df.columns:
                 if agg.get(col, 'sum') == 'mean':

@@ -77,13 +126,19 @@ class Clusters:
                 actual_values[col] = df[col].sum()
             actual = pd.Series(actual_values)

-            # Calculate estimate values
             reps_unscaled = self.extract_reps(df)
             estimate_values = {}

-            for col in df.columns:
+            for col in df.columns:  # Iterate over original df columns to ensure all are covered
+                if col not in reps_unscaled.columns:  # Column might not be in reps_unscaled if it was dropped or not selected
+                    if agg.get(col, 'sum') == 'mean':
+                        estimate_values[col] = np.nan  # Or some other placeholder like 0, or actual.get(col, 0)
+                    else:
+                        estimate_values[col] = 0
+                    gr.Warning(f"Column '{col}' not found in representative policies output for 'compare_total'. Estimate will be 0/NaN.")
+                    continue
+
                 if agg.get(col, 'sum') == 'mean':
-                    # Weighted average for mean columns
                     weighted_sum = (reps_unscaled[col] * self.policy_count).sum()
                     total_weight = self.policy_count.sum()
                     estimate_values[col] = weighted_sum / total_weight if total_weight > 0 else 0

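For columns aggregated with 'mean' (such as age_at_entry), the estimate above is a policy-count-weighted average of the representative values rather than a scaled sum. With hypothetical numbers:

    import pandas as pd

    reps_age = pd.Series([30.0, 50.0])     # representative age_at_entry per cluster
    policy_count = pd.Series([100, 50])    # policies per cluster

    weighted_mean = (reps_age * policy_count).sum() / policy_count.sum()
    print(weighted_mean)  # (30*100 + 50*50) / 150 = 36.67
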
@@ -96,12 +151,14 @@ class Clusters:
             actual = df.sum()
             estimate = self.extract_and_scale_reps(df).sum()

-        #
+        # Ensure alignment for error calculation
+        actual, estimate = actual.align(estimate, fill_value=0)
         error = np.where(actual != 0, estimate / actual - 1, 0)

         return pd.DataFrame({'actual': actual, 'estimate': estimate, 'error': error})


+# --- Plotting functions (plot_cashflows_comparison, plot_scatter_comparison) ---
 def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
     """Create cashflow comparison plots"""
     if not cfs_list or not cluster_obj or not titles:

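The `align` call protects the error calculation when `actual` and `estimate` do not share exactly the same index. A minimal demonstration of the pattern:

    import numpy as np
    import pandas as pd

    actual = pd.Series({'a': 100.0, 'b': 50.0})
    estimate = pd.Series({'a': 90.0, 'c': 10.0})

    actual, estimate = actual.align(estimate, fill_value=0)  # outer join on the index, gaps become 0
    error = np.where(actual != 0, estimate / actual - 1, 0)
    print(pd.DataFrame({'actual': actual, 'estimate': estimate, 'error': error}))
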
@@ -110,7 +167,6 @@ def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
     if num_plots == 0:
         return None

-    # Determine subplot layout
     cols = 2
     rows = (num_plots + cols - 1) // cols

@@ -119,12 +175,17 @@ def plot_cashflows_comparison(cfs_list, cluster_obj, titles):

     for i, (df, title) in enumerate(zip(cfs_list, titles)):
         if i < len(axes):
-
+            # Ensure df passed to compare_total is appropriate.
+            # If df has policy_id as index, it matches expectations of downstream functions in Clusters.
+            # If not, ensure policy_id is a column or handle appropriately.
+            if df.index.name != 'policy_id' and 'policy_id' not in df.columns:
+                gr.Warning(f"DataFrame for plot '{title}' does not have 'policy_id' as index or column. Results may be incorrect.")
+
+            comparison = cluster_obj.compare_total(df.set_index('policy_id') if 'policy_id' in df.columns and df.index.name != 'policy_id' else df)
             comparison[['actual', 'estimate']].plot(ax=axes[i], grid=True, title=title)
             axes[i].set_xlabel('Time')
             axes[i].set_ylabel('Value')

-    # Hide any unused subplots
     for j in range(i + 1, len(axes)):
         fig.delaxes(axes[j])

@@ -139,7 +200,6 @@ def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
 def plot_scatter_comparison(df_compare_output, title):
     """Create scatter plot comparison from compare() output"""
     if df_compare_output is None or df_compare_output.empty:
-        # Create a blank plot with a message
         fig, ax = plt.subplots(figsize=(12, 8))
         ax.text(0.5, 0.5, "No data to display", ha='center', va='center', fontsize=15)
         ax.set_title(title)

@@ -153,29 +213,28 @@ def plot_scatter_comparison(df_compare_output, title):
     fig, ax = plt.subplots(figsize=(12, 8))

     if not isinstance(df_compare_output.index, pd.MultiIndex) or df_compare_output.index.nlevels < 2:
-
-
+        gr.Warning("Scatter plot data is not in the expected multi-index format. Plotting raw actual vs estimate.")
+        ax.scatter(df_compare_output['actual'], df_compare_output['estimate'], s=9, alpha=0.6)
     else:
         unique_levels = df_compare_output.index.get_level_values(1).unique()
         colors = matplotlib.cm.rainbow(np.linspace(0, 1, len(unique_levels)))

         for item_level, color_val in zip(unique_levels, colors):
             subset = df_compare_output.xs(item_level, level=1)
-            ax.scatter(subset['actual'], subset['estimate'], color=color_val, s=9, alpha=0.6, label=item_level)
-        if len(unique_levels) > 1 and len(unique_levels) <=
-            ax.legend(title=df_compare_output.index.names[1])
+            ax.scatter(subset['actual'], subset['estimate'], color=color_val, s=9, alpha=0.6, label=str(item_level))  # Ensure label is string
+        if len(unique_levels) > 1 and len(unique_levels) <= 20:  # Increased legend item limit slightly
+            ax.legend(title=str(df_compare_output.index.names[1]))

     ax.set_xlabel('Actual')
     ax.set_ylabel('Estimate')
     ax.set_title(title)
     ax.grid(True)

-    # Draw identity line
     lims = [
-        np.
-        np.
+        np.nanmin([ax.get_xlim(), ax.get_ylim()]),  # Use nanmin/nanmax
+        np.nanmax([ax.get_xlim(), ax.get_ylim()]),
     ]
-    if lims[0] != lims[1]:
+    if lims[0] != lims[1] and np.isfinite(lims[0]) and np.isfinite(lims[1]):  # Check for valid limits
         ax.plot(lims, lims, 'r-', linewidth=0.5)
         ax.set_xlim(lims)
         ax.set_ylim(lims)

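The red line drawn here is the identity y = x: points on it are clusters whose estimate matches the actual value exactly. A minimal sketch of the nan-safe limit pattern used above, on throwaway data:

    import numpy as np
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()
    ax.scatter([1, 2, 3], [1.1, 1.9, 3.2], s=9)
    lims = [np.nanmin([ax.get_xlim(), ax.get_ylim()]),
            np.nanmax([ax.get_xlim(), ax.get_ylim()])]
    ax.plot(lims, lims, 'r-', linewidth=0.5)  # identity reference line
    ax.set_xlim(lims)
    ax.set_ylim(lims)
    plt.close(fig)
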
@@ -187,28 +246,55 @@ def plot_scatter_comparison(df_compare_output, title):
     plt.close(fig)
     return img

-
+# --- Main processing function (process_files) ---
+# Ensure DataFrames passed to Clusters methods have 'policy_id' as index if expected.
 def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
                   policy_data_path, pv_base_path, pv_lapse_path, pv_mort_path):
     """Main processing function - now accepts file paths"""
     try:
-        #
+        # Consider using engine='calamine' for faster Excel reading if available (pip install pandas[calamine]),
+        # e.g., cfs = pd.read_excel(cashflow_base_path, index_col=0, engine='calamine')
         cfs = pd.read_excel(cashflow_base_path, index_col=0)
         cfs_lapse50 = pd.read_excel(cashflow_lapse_path, index_col=0)
         cfs_mort15 = pd.read_excel(cashflow_mort_path, index_col=0)

         pol_data_full = pd.read_excel(policy_data_path, index_col=0)
-        # Ensure the correct columns are selected for pol_data
         required_cols = ['age_at_entry', 'policy_term', 'sum_assured', 'duration_mth']
-
-
+
+        # Ensure the index is named 'policy_id' if it has no name, assuming the index is the policy identifier.
+        for df in [cfs, cfs_lapse50, cfs_mort15, pol_data_full]:
+            if df.index.name is None:
+                df.index.name = 'policy_id'
+            if 'policy_id' not in df.columns and df.index.name == 'policy_id':  # Add policy_id as a column if it's only an index
+                df.reset_index(inplace=True)   # this makes policy_id a column
+                df.set_index('policy_id', inplace=True)  # and keeps it as index
+
+        if all(col in pol_data_full.columns or col == pol_data_full.index.name for col in required_cols):
+            # If policy_id is the index, it won't be in columns. Adjust the selection.
+            cols_to_select = [col for col in required_cols if col in pol_data_full.columns]
+            if pol_data_full.index.name in required_cols and pol_data_full.index.name not in cols_to_select:
+                # This case is tricky: an ID that is part of required_cols and is the index.
+                # For simplicity, assume required_cols are actual data columns.
+                pass  # Let it proceed; it may be handled by selection or raise an error later.
+
+            pol_data = pol_data_full[cols_to_select].copy()  # Use .copy() to avoid SettingWithCopyWarning
+            # If 'policy_id' was the index and required, it's implicitly handled or needs specific logic.
+            # For K-Means, policy_id itself is usually not a feature.
         else:
-
-
+            missing_req_cols = [col for col in required_cols if col not in pol_data_full.columns and col != pol_data_full.index.name]
+            gr.Warning(f"Policy data might be missing required columns: {missing_req_cols}. Found: {pol_data_full.columns.tolist()}")
+            pol_data = pol_data_full  # Fallback, but ensure it's numeric for clustering/scaling

         pvs = pd.read_excel(pv_base_path, index_col=0)
         pvs_lapse50 = pd.read_excel(pv_lapse_path, index_col=0)
         pvs_mort15 = pd.read_excel(pv_mort_path, index_col=0)
+
+        for df in [pvs, pvs_lapse50, pvs_mort15]:
+            if df.index.name is None:
+                df.index.name = 'policy_id'
+            if 'policy_id' not in df.columns and df.index.name == 'policy_id':
+                df.reset_index(inplace=True)
+                df.set_index('policy_id', inplace=True)

         cfs_list = [cfs, cfs_lapse50, cfs_mort15]
         scen_titles = ['Base', 'Lapse+50%', 'Mort+15%']

@@ -217,8 +303,14 @@ def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,

         mean_attrs = {'age_at_entry':'mean', 'policy_term':'mean', 'duration_mth':'mean', 'sum_assured': 'sum'}

+        # DataFrames passed to Clusters should be policy_id indexed for .values to exclude it.
+        # Or, select only feature columns before passing.
+        # The Clusters class now expects a DataFrame and will use .values, so pass only feature columns.
+        # If index is policy_id, df.values will not include it. This is good.
+
         # --- 1. Cashflow Calibration ---
-
+        # Ensure 'cfs' DataFrame does not include 'policy_id' when .values is called in Clusters
+        cluster_cfs = Clusters(cfs.reset_index().set_index('policy_id'))  # Pass with policy_id as index

         results['cf_total_base_table'] = cluster_cfs.compare_total(cfs)
         results['cf_policy_attrs_total'] = cluster_cfs.compare_total(pol_data, agg=mean_attrs)

results['cf_scatter_cashflows_base'] = plot_scatter_comparison(cluster_cfs.compare(cfs), 'Cashflow Calib. - Cashflows (Base)')
|
324 |
|
325 |
# --- 2. Policy Attribute Calibration ---
|
326 |
+
loc_vars_attrs = pd.DataFrame() # Initialize
|
327 |
+
if not pol_data.empty:
|
328 |
+
# Ensure pol_data is purely numeric for scaling and KMeans
|
329 |
+
numeric_pol_data = pol_data.select_dtypes(include=np.number)
|
330 |
+
if not numeric_pol_data.empty and not (numeric_pol_data.max(numeric_only=True) - numeric_pol_data.min(numeric_only=True) == 0).all():
|
331 |
+
loc_vars_attrs = (numeric_pol_data - numeric_pol_data.min(numeric_only=True)) / \
|
332 |
+
(numeric_pol_data.max(numeric_only=True) - numeric_pol_data.min(numeric_only=True))
|
333 |
+
loc_vars_attrs.index = numeric_pol_data.index # Preserve index
|
334 |
+
else:
|
335 |
+
gr.Warning("Policy data for attribute calibration is empty, non-numeric, or has no variance. Skipping attribute calibration content.")
|
336 |
+
loc_vars_attrs = numeric_pol_data # or an empty DataFrame with original index
|
337 |
else:
|
338 |
+
gr.Warning("Policy data is empty. Skipping attribute calibration content.")
|
339 |
+
|
|
|
340 |
if not loc_vars_attrs.empty:
|
341 |
+
cluster_attrs = Clusters(loc_vars_attrs.reset_index().set_index('policy_id')) # Pass with policy_id as index
|
342 |
results['attr_total_cf_base'] = cluster_attrs.compare_total(cfs)
|
343 |
results['attr_policy_attrs_total'] = cluster_attrs.compare_total(pol_data, agg=mean_attrs)
|
344 |
results['attr_total_pv_base'] = cluster_attrs.compare_total(pvs)
|
|
|
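Min-max scaling maps each attribute onto the [0, 1] range so that large-magnitude columns such as sum_assured do not dominate the k-means distance. A small sketch of the normalization used above (hypothetical values):

    import pandas as pd

    pol = pd.DataFrame({'age_at_entry': [25, 40, 60],
                        'sum_assured': [10_000, 50_000, 90_000]})

    scaled = (pol - pol.min()) / (pol.max() - pol.min())
    print(scaled)  # each column now spans 0.0 to 1.0
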
@@ -249,11 +348,12 @@ def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
             results['attr_total_cf_base'] = pd.DataFrame()
             results['attr_policy_attrs_total'] = pd.DataFrame()
             results['attr_total_pv_base'] = pd.DataFrame()
-            results['attr_cashflow_plot'] =
-            results['attr_scatter_cashflows_base'] =
+            results['attr_cashflow_plot'] = plot_scatter_comparison(pd.DataFrame(), 'Policy Attr. Calib. - Cashflows (Base) - No Data')  # Empty plot
+            results['attr_scatter_cashflows_base'] = plot_scatter_comparison(pd.DataFrame(), 'Policy Attr. Calib. - Cashflows (Base) - No Data')

         # --- 3. Present Value Calibration ---
-        cluster_pvs = Clusters(pvs)
+        cluster_pvs = Clusters(pvs.reset_index().set_index('policy_id'))  # Pass with policy_id as index

         results['pv_total_cf_base'] = cluster_pvs.compare_total(cfs)
         results['pv_policy_attrs_total'] = cluster_pvs.compare_total(pol_data, agg=mean_attrs)

@@ -266,52 +366,47 @@ def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
         results['pv_scatter_pvs_base'] = plot_scatter_comparison(cluster_pvs.compare(pvs), 'PV Calib. - PVs (Base)')

         # --- Summary Comparison Plot Data ---
-        # Error metric for key PV column or mean absolute error
-
         error_data = {}

-        # Function to safely get error value
         def get_error_safe(compare_result, col_name=None):
-            if compare_result.empty:
+            if compare_result is None or compare_result.empty or 'error' not in compare_result.columns:  # Check if None
                 return np.nan
             if col_name and col_name in compare_result.index:
                 return abs(compare_result.loc[col_name, 'error'])
             else:
-                # Use mean absolute error if specific column not found
                 return abs(compare_result['error']).mean()

-        # Determine key PV column (try common names)
         key_pv_col = None
+        # Use pvs.columns (which should be only feature columns after reset_index().set_index()).
+        # Or, use the original pvs DataFrame if it's guaranteed to have the PV_NetCF column.
+        # For safety, check in the original pvs DataFrame, which has not been stripped of columns.
+        original_pvs_cols = pd.read_excel(pv_base_path).columns  # Quick read just for columns
         for potential_col in ['PV_NetCF', 'pv_net_cf', 'net_cf_pv', 'PV_Net_CF']:
-            if potential_col in
+            if potential_col in original_pvs_cols:  # Check against original columns
                 key_pv_col = potential_col
                 break

-        # Cashflow Calibration Errors
         error_data['CF Calib.'] = [
-            get_error_safe(
-            get_error_safe(
-            get_error_safe(
+            get_error_safe(results.get('cf_pv_total_base'), key_pv_col),
+            get_error_safe(results.get('cf_pv_total_lapse'), key_pv_col),
+            get_error_safe(results.get('cf_pv_total_mort'), key_pv_col)
         ]

-        # Policy Attribute Calibration Errors
         if not loc_vars_attrs.empty:
             error_data['Attr Calib.'] = [
-                get_error_safe(
-                get_error_safe(cluster_attrs.compare_total(pvs_lapse50), key_pv_col),
-                get_error_safe(cluster_attrs.compare_total(pvs_mort15), key_pv_col)
+                get_error_safe(results.get('attr_total_pv_base'), key_pv_col),  # This was pvs, should be fine
+                get_error_safe(cluster_attrs.compare_total(pvs_lapse50), key_pv_col),  # Re-calculate for pvs_lapse50
+                get_error_safe(cluster_attrs.compare_total(pvs_mort15), key_pv_col)  # Re-calculate for pvs_mort15
             ]
         else:
             error_data['Attr Calib.'] = [np.nan, np.nan, np.nan]

-        # Present Value Calibration Errors
         error_data['PV Calib.'] = [
-            get_error_safe(
-            get_error_safe(
-            get_error_safe(
+            get_error_safe(results.get('pv_total_pv_base'), key_pv_col),
+            get_error_safe(results.get('pv_total_pv_lapse'), key_pv_col),
+            get_error_safe(results.get('pv_total_pv_mort'), key_pv_col)
         ]

-        # Create Summary Plot
         summary_df = pd.DataFrame(error_data, index=['Base', 'Lapse+50%', 'Mort+15%'])

         fig_summary, ax_summary = plt.subplots(figsize=(10, 6))

@@ -335,13 +430,15 @@ def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
         gr.Error(f"File not found: {e.filename}. Please ensure example files are in '{EXAMPLE_DATA_DIR}' or all files are uploaded.")
         return {"error": f"File not found: {e.filename}"}
     except KeyError as e:
-
-
+        # Check if the KeyError is from trying to access a column that became an index
+        gr.Error(f"A required column or index is missing or misnamed: {e}. Please check data format and ensure 'policy_id' is correctly handled as index for feature dataframes.")
+        return {"error": f"Missing column/index: {e}"}
     except Exception as e:
-
+        import traceback
+        gr.Error(f"Error processing files: {str(e)}. Trace: {traceback.format_exc()}")
         return {"error": f"Error processing files: {str(e)}"}

-
+# --- Gradio interface creation (create_interface, etc.) ---
 def create_interface():
     with gr.Blocks(title="Cluster Model Points Analysis") as demo:
         gr.Markdown("""

@@ -351,13 +448,15 @@ def create_interface():
        Upload your Excel files or use the example data to analyze cashflows, policy attributes, and present values using different calibration methods.

        **Required Files (Excel .xlsx):**
-        - Cashflows - Base Scenario
-        - Cashflows - Lapse Stress (+50%)
-        - Cashflows - Mortality Stress (+15%)
-        - Policy Data (including 'age_at_entry', 'policy_term', 'sum_assured', 'duration_mth')
-        - Present Values - Base Scenario
-        - Present Values - Lapse Stress
-        - Present Values - Mortality Stress
+        - Cashflows - Base Scenario (index = policy_id, columns = time periods)
+        - Cashflows - Lapse Stress (+50%) (index = policy_id)
+        - Cashflows - Mortality Stress (+15%) (index = policy_id)
+        - Policy Data (index = policy_id, including 'age_at_entry', 'policy_term', 'sum_assured', 'duration_mth' as columns)
+        - Present Values - Base Scenario (index = policy_id, columns = PV components like 'PV_NetCF')
+        - Present Values - Lapse Stress (index = policy_id)
+        - Present Values - Mortality Stress (index = policy_id)
+
+        *Note: Ensure 'policy_id' is the index for all input files for correct processing.*
        """)

        with gr.Row():

@@ -404,7 +503,11 @@ def create_interface():
                     attr_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate) Across Scenarios")
                     attr_scatter_cashflows_base_out = gr.Image(label="Scatter Plot - Per-Cluster Cashflows (Base Scenario)")
                 with gr.Accordion("Present Value Comparisons (Total)", open=False):
-
+                    with gr.Row():  # Changed to Row for consistency
+                        attr_total_pv_base_out = gr.Dataframe(label="PVs - Base Scenario Total")
+                        # Added placeholders for other scenarios if they were intended
+                        # attr_total_pv_lapse_out = gr.Dataframe(label="PVs - Lapse Stress Total")
+                        # attr_total_pv_mort_out = gr.Dataframe(label="PVs - Mortality Stress Total")

             with gr.TabItem("💰 Present Value Calibration"):
                 gr.Markdown("### Results: Using Present Values (Base Scenario) as Calibration Variables")

@@ -441,24 +544,28 @@ def create_interface():
            files = [f1, f2, f3, f4, f5, f6, f7]

            file_paths = []
+            # Check whether any file input is None (no file uploaded for a slot).
+            if any(f_obj is None for f_obj in files):
+                # Attempting to load from EXAMPLE_FILES when inputs are missing
+                # could get complex if examples and uploads are mixed.
+                # For now, be strict: all files must be present.
+                gr.Error("Missing file input for one or more fields. Please upload all required files or load the complete example dataset.")
+                return [None] * len(get_all_output_components())
+
            for i, f_obj in enumerate(files):
-
-
-                    return [None] * len(get_all_output_components())
-
-                # If f_obj is a Gradio FileData object (from direct upload)
-                if hasattr(f_obj, 'name') and isinstance(f_obj.name, str):
+                # f_obj is a TempFilePath (older Gradio), FileData (newer), or str (from example load).
+                if hasattr(f_obj, 'name') and isinstance(f_obj.name, str):  # Gradio FileData or similar
                    file_paths.append(f_obj.name)
-
-
-
-                else:
+                elif isinstance(f_obj, str):  # Path from example load
+                    file_paths.append(f_obj)
+                else:  # Should not happen if inputs are Files or paths
                    gr.Error(f"Invalid file input for argument {i+1}. Type: {type(f_obj)}")
                    return [None] * len(get_all_output_components())

            results = process_files(*file_paths)

-            if "error" in results:
+            if "error" in results:  # Check whether process_files returned an error dict
+                # The error was already shown by gr.Error in process_files.
                return [None] * len(get_all_output_components())

            return [

@@ -485,18 +592,48 @@ def create_interface():

        # --- Action for Load Example Data Button ---
        def load_example_files():
-
-
-
-
+            # Create dummy example files if they don't exist, for demonstration purposes.
+            # For this exercise, we assume they exist or the user is warned.
+            os.makedirs(EXAMPLE_DATA_DIR, exist_ok=True)  # Ensure dir exists
+
+            missing_files = []
+            for key, fp in EXAMPLE_FILES.items():
+                if not os.path.exists(fp):
+                    missing_files.append(fp)
+                    # Create a minimal dummy Excel file if it's missing
+                    try:
+                        dummy_df_data = {'policy_id': [1, 2, 3], 'col1': [0.1, 0.2, 0.3], 'col2': [10, 20, 30]}
+                        if "cashflow" in key or "pv" in key:  # Time series like
+                            dummy_df_data = {'policy_id': [1, 2, 3], '0': [1, 2, 3], '1': [4, 5, 6]}
+                        elif "policy_data" in key:
+                            dummy_df_data = {'policy_id': [1, 2, 3], 'age_at_entry': [20, 30, 40], 'policy_term': [10, 20, 15],
+                                             'sum_assured': [1000, 2000, 1500], 'duration_mth': [5, 10, 7]}
+
+                        dummy_df = pd.DataFrame(dummy_df_data).set_index('policy_id')
+                        dummy_df.to_excel(fp)
+                        gr.Warning(f"Example file '{fp}' was missing and a dummy file has been created. Results may not be meaningful.")
+                    except Exception as e:
+                        gr.Warning(f"Could not create dummy file for {fp}: {e}")
+
+            if missing_files and not all(os.path.exists(fp) for fp in EXAMPLE_FILES.values()):  # Re-check after dummy creation attempt
+                # If still missing after trying to create dummies
+                gr.Error(f"Critical example data files are missing from '{EXAMPLE_DATA_DIR}': {', '.join(missing_files)}. Please ensure they exist or check permissions.")
+                return [None] * 7  # Return None for all file inputs

            gr.Info("Example data paths loaded. Click 'Analyze Dataset'.")
+            # Return the example paths as gr.File updates for the file components.
            return [
-                EXAMPLE_FILES["cashflow_base"],
-                EXAMPLE_FILES["
-                EXAMPLE_FILES["
+                gr.File(value=EXAMPLE_FILES["cashflow_base"], label=cashflow_base_input.label),
+                gr.File(value=EXAMPLE_FILES["cashflow_lapse"], label=cashflow_lapse_input.label),
+                gr.File(value=EXAMPLE_FILES["cashflow_mort"], label=cashflow_mort_input.label),
+                gr.File(value=EXAMPLE_FILES["policy_data"], label=policy_data_input.label),
+                gr.File(value=EXAMPLE_FILES["pv_base"], label=pv_base_input.label),
+                gr.File(value=EXAMPLE_FILES["pv_lapse"], label=pv_lapse_input.label),
+                gr.File(value=EXAMPLE_FILES["pv_mort"], label=pv_mort_input.label)
            ]

        load_example_btn.click(
            load_example_files,
            inputs=[],

@@ -509,8 +646,30 @@
 if __name__ == "__main__":
     if not os.path.exists(EXAMPLE_DATA_DIR):
         os.makedirs(EXAMPLE_DATA_DIR)
-        print(f"Created directory '{EXAMPLE_DATA_DIR}'. Please place example Excel files there.")
-
+        print(f"Created directory '{EXAMPLE_DATA_DIR}'. Please place example Excel files there or they will be generated as dummies.")
+
+    # Simple check and dummy file creation for example data if not present
+    for key, fp in EXAMPLE_FILES.items():
+        if not os.path.exists(fp):
+            print(f"Example file {fp} not found. Attempting to create a dummy file.")
+            try:
+                dummy_df_data = {'policy_id': [1, 2, 3], 'col1': [0.1, 0.2, 0.3], 'col2': [10, 20, 30]}
+                if "cashflow" in key or "pv" in key:
+                    dummy_df_data = {f'{i}': np.random.rand(3) for i in range(10)}  # 10 time periods
+                    dummy_df_data['policy_id'] = [f'P{j}' for j in range(3)]
+                elif "policy_data" in key:
+                    dummy_df_data = {'policy_id': [f'P{j}' for j in range(3)],
+                                     'age_at_entry': np.random.randint(20, 50, 3),
+                                     'policy_term': np.random.randint(10, 30, 3),
+                                     'sum_assured': np.random.randint(10000, 50000, 3),
+                                     'duration_mth': np.random.randint(1, 120, 3)}
+
+                dummy_df = pd.DataFrame(dummy_df_data).set_index('policy_id')
+                dummy_df.to_excel(fp)
+                print(f"Dummy file for '{fp}' created.")
+            except Exception as e:
+                print(f"Could not create dummy file for {fp}: {e}")

     demo_app = create_interface()
     demo_app.launch()