Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -5,280 +5,139 @@ from sklearn.cluster import KMeans
|
|
5 |
from sklearn.metrics import r2_score, pairwise_distances_argmin_min
|
6 |
import matplotlib.pyplot as plt
|
7 |
import io
|
8 |
-
import os # For checking file extensions
|
9 |
|
10 |
def cluster_analysis(policy_file, cashflow_file, pv_file, num_clusters):
|
11 |
-
|
12 |
-
Performs cluster analysis for actuarial model point selection.
|
13 |
-
Accepts both Excel and CSV files.
|
14 |
-
|
15 |
-
Args:
|
16 |
-
policy_file: Gradio File object for policy data.
|
17 |
-
cashflow_file: Gradio File object for cashflow data.
|
18 |
-
pv_file: Gradio File object for present value data.
|
19 |
-
num_clusters: Number of clusters (model points) to generate.
|
20 |
-
|
21 |
-
Returns:
|
22 |
-
A tuple: (csv_data_string, cashflow_plot_bytes, pv_plot_bytes, metrics_text_string)
|
23 |
-
Returns (None, None, None, error_message_string) if an error occurs.
|
24 |
-
"""
|
25 |
-
# Initialize outputs to None or empty strings for Gradio components
|
26 |
-
csv_data = ""
|
27 |
-
cashflow_plot = None
|
28 |
-
pv_plot = None
|
29 |
-
metrics_text = "Starting analysis..." # Initial status message
|
30 |
-
|
31 |
-
# --- Start of detailed logging ---
|
32 |
-
print("\n" + "="*50)
|
33 |
-
print(f"[{pd.Timestamp.now()}] --- cluster_analysis function called ---")
|
34 |
-
print(f"Received num_clusters: {num_clusters}")
|
35 |
-
print(f"Policy file received: {policy_file.name if policy_file else 'None'}")
|
36 |
-
print(f"Cashflow file received: {cashflow_file.name if cashflow_file else 'None'}")
|
37 |
-
print(f"PV file received: {pv_file.name if pv_file else 'None'}")
|
38 |
-
print("="*50 + "\n")
|
39 |
-
|
40 |
-
# Helper function to read files based on extension
|
41 |
-
def read_data_file(file_obj, index_col=None):
|
42 |
-
if file_obj is None:
|
43 |
-
raise ValueError("File object is None.")
|
44 |
-
|
45 |
-
file_path = file_obj.name
|
46 |
-
file_extension = os.path.splitext(file_path)[1].lower()
|
47 |
-
|
48 |
-
if file_extension in ['.xlsx', '.xls']:
|
49 |
-
print(f"Attempting to read Excel file: {file_path}")
|
50 |
-
return pd.read_excel(file_path, index_col=index_col)
|
51 |
-
elif file_extension == '.csv':
|
52 |
-
print(f"Attempting to read CSV file: {file_path}")
|
53 |
-
# Consider adding 'sep' argument if CSV delimiter is not comma, e.g., sep=';'
|
54 |
-
return pd.read_csv(file_path, index_col=index_col)
|
55 |
-
else:
|
56 |
-
raise ValueError(f"Unsupported file type: {file_extension}. Please upload .xlsx, .xls, or .csv files.")
|
57 |
-
|
58 |
-
# 1. Basic checks and file reading
|
59 |
try:
|
60 |
-
|
61 |
-
|
62 |
-
cashflow_df =
|
63 |
-
pv_df =
|
64 |
-
print(f"[{pd.Timestamp.now()}] Files read successfully.")
|
65 |
-
print(f"Policy data shape: {policy_df.shape}, Columns: {policy_df.columns.tolist()}")
|
66 |
-
print(f"Cashflow data shape: {cashflow_df.shape}, Index type: {cashflow_df.index.dtype}")
|
67 |
-
print(f"PV data shape: {pv_df.shape}, Index type: {pv_df.index.dtype}")
|
68 |
-
|
69 |
except Exception as e:
|
70 |
-
|
71 |
-
print(f"[{pd.Timestamp.now()}] {error_msg}")
|
72 |
-
return (None, None, None, error_msg)
|
73 |
|
74 |
-
#
|
|
|
75 |
required_cols = ['IssueAge', 'PolicyTerm', 'SumAssured', 'Duration']
|
76 |
-
# Strip whitespace from column names for robust matching
|
77 |
-
policy_df.columns = policy_df.columns.str.strip()
|
78 |
-
|
79 |
if not all(col in policy_df.columns for col in required_cols):
|
80 |
-
|
81 |
-
|
82 |
-
f"Found: {found_cols}. Please check your policy data column headers for typos or extra spaces.")
|
83 |
-
print(f"[{pd.Timestamp.now()}] {error_msg}")
|
84 |
-
return (None, None, None, error_msg)
|
85 |
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
print(f"[{pd.Timestamp.now()}] {error_msg}")
|
96 |
-
return (None, None, None, error_msg)
|
97 |
|
98 |
-
#
|
99 |
try:
|
100 |
-
|
101 |
-
if num_clusters <= 1:
|
102 |
-
error_msg = "Number of clusters must be at least 2."
|
103 |
-
print(f"[{pd.Timestamp.now()}] {error_msg}")
|
104 |
-
return (None, None, None, error_msg)
|
105 |
-
if num_clusters > n_samples:
|
106 |
-
original_num_clusters = num_clusters
|
107 |
-
num_clusters = n_samples # Adjust if clusters > samples
|
108 |
-
print(f"[{pd.Timestamp.now()}] Warning: Number of clusters ({original_num_clusters}) "
|
109 |
-
f"exceeded number of samples ({n_samples}). Reduced to {num_clusters}.")
|
110 |
-
|
111 |
-
# Use 'auto' for n_init for newer scikit-learn versions
|
112 |
-
kmeans = KMeans(n_clusters=num_clusters, random_state=42, n_init='auto')
|
113 |
kmeans.fit(X_scaled)
|
114 |
policy_df['Cluster'] = kmeans.labels_
|
115 |
-
print(f"[{pd.Timestamp.now()}] Clustering successful with {num_clusters} clusters.")
|
116 |
except Exception as e:
|
117 |
-
|
118 |
-
print(f"[{pd.Timestamp.now()}] {error_msg}")
|
119 |
-
return (None, None, None, error_msg)
|
120 |
|
121 |
-
#
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
return (None, None, None,
|
135 |
|
136 |
-
#
|
137 |
csv_buffer = io.StringIO()
|
138 |
-
model_points.to_csv(csv_buffer
|
139 |
csv_data = csv_buffer.getvalue()
|
140 |
-
print(f"[{pd.Timestamp.now()}] Model points CSV generated.")
|
141 |
|
142 |
-
#
|
143 |
-
|
144 |
-
|
145 |
-
missing_cf_indices = [idx for idx in model_points.index if idx not in cashflow_df.index]
|
146 |
-
if missing_cf_indices:
|
147 |
-
raise KeyError(f"Cashflow data is missing entries for model point indices: {missing_cf_indices[:5]}... Please check Cashflow data's first column (index).")
|
148 |
-
|
149 |
-
proxy_cashflows = cashflow_df.loc[model_points.index].multiply(model_points['Weight'], axis=0).sum()
|
150 |
-
seriatim_cashflows = cashflow_df.sum()
|
151 |
-
print(f"[{pd.Timestamp.now()}] Cashflows aggregated.")
|
152 |
-
except KeyError as e:
|
153 |
-
error_msg = (f"Key Error during cashflow aggregation. "
|
154 |
-
f"Ensure the first column of your Cashflow Excel/CSV file contains policy IDs "
|
155 |
-
f"that match the indices from your Policy data: {e}")
|
156 |
-
print(f"[{pd.Timestamp.now()}] {error_msg}")
|
157 |
-
return (None, None, None, error_msg)
|
158 |
-
except Exception as e:
|
159 |
-
error_msg = f"Error aggregating cashflows: {e}"
|
160 |
-
print(f"[{pd.Timestamp.now()}] {error_msg}")
|
161 |
-
return (None, None, None, error_msg)
|
162 |
-
|
163 |
-
# 8. Plot aggregated cashflows
|
164 |
-
try:
|
165 |
-
fig, ax = plt.subplots(figsize=(10, 5)) # Slightly larger plot
|
166 |
-
seriatim_cashflows.plot(ax=ax, label='Seriatim Cashflows', marker='.', linestyle='-')
|
167 |
-
proxy_cashflows.plot(ax=ax, label='Proxy Cashflows', marker='x', linestyle='--')
|
168 |
-
ax.set_title('Aggregated Cashflows Comparison')
|
169 |
-
ax.set_xlabel('Period')
|
170 |
-
ax.set_ylabel('Cashflow Amount')
|
171 |
-
ax.legend()
|
172 |
-
ax.grid(True, linestyle=':', alpha=0.7)
|
173 |
-
plt.tight_layout() # Adjust layout to prevent labels from overlapping
|
174 |
-
|
175 |
-
buf = io.BytesIO()
|
176 |
-
plt.savefig(buf, format='png')
|
177 |
-
plt.close(fig)
|
178 |
-
buf.seek(0)
|
179 |
-
cashflow_plot = buf.read()
|
180 |
-
print(f"[{pd.Timestamp.now()}] Cashflow plot generated.")
|
181 |
-
except Exception as e:
|
182 |
-
error_msg = f"Error generating cashflow plot: {e}"
|
183 |
-
print(f"[{pd.Timestamp.now()}] {error_msg}")
|
184 |
-
return (None, None, None, error_msg)
|
185 |
-
|
186 |
-
# 9. Aggregate present values
|
187 |
-
try:
|
188 |
-
missing_pv_indices = [idx for idx in model_points.index if idx not in pv_df.index]
|
189 |
-
if missing_pv_indices:
|
190 |
-
raise KeyError(f"PV data is missing entries for model point indices: {missing_pv_indices[:5]}... Please check PV data's first column (index).")
|
191 |
-
|
192 |
-
# Assuming PV data has only one column or the relevant column is the first one
|
193 |
-
proxy_pv = pv_df.loc[model_points.index].multiply(model_points['Weight'], axis=0).sum().values[0]
|
194 |
-
seriatim_pv = pv_df.sum().values[0] # Assuming total PV is in the first column
|
195 |
-
print(f"[{pd.Timestamp.now()}] Present Values aggregated.")
|
196 |
-
except KeyError as e:
|
197 |
-
error_msg = (f"Key Error during PV aggregation. "
|
198 |
-
f"Ensure the first column of your PV Excel/CSV file contains policy IDs "
|
199 |
-
f"that match the indices from your Policy data: {e}")
|
200 |
-
print(f"[{pd.Timestamp.now()}] {error_msg}")
|
201 |
-
return (None, None, None, error_msg)
|
202 |
-
except Exception as e:
|
203 |
-
error_msg = f"Error aggregating present values: {e}"
|
204 |
-
print(f"[{pd.Timestamp.now()}] {error_msg}")
|
205 |
-
return (None, None, None, error_msg)
|
206 |
-
|
207 |
-
# 10. Present Value comparison plot (bar)
|
208 |
-
try:
|
209 |
-
fig2, ax2 = plt.subplots(figsize=(6, 5)) # Adjust size
|
210 |
-
ax2.bar(['Seriatim PV', 'Proxy PV'], [seriatim_pv, proxy_pv], color=['#1f77b4', '#ff7f0e']) # Use nicer colors
|
211 |
-
ax2.set_title('Aggregated Present Values Comparison')
|
212 |
-
ax2.set_ylabel('Present Value')
|
213 |
-
ax2.grid(axis='y', linestyle=':', alpha=0.7)
|
214 |
-
plt.tight_layout()
|
215 |
-
|
216 |
-
buf2 = io.BytesIO()
|
217 |
-
plt.savefig(buf2, format='png')
|
218 |
-
plt.close(fig2)
|
219 |
-
buf2.seek(0)
|
220 |
-
pv_plot = buf2.read()
|
221 |
-
print(f"[{pd.Timestamp.now()}] PV plot generated.")
|
222 |
-
except Exception as e:
|
223 |
-
error_msg = f"Error generating PV plot: {e}"
|
224 |
-
print(f"[{pd.Timestamp.now()}] {error_msg}")
|
225 |
-
return (None, None, None, error_msg)
|
226 |
-
|
227 |
-
# 11. Calculate Accuracy metrics
|
228 |
-
try:
|
229 |
-
common_idx = seriatim_cashflows.index.intersection(proxy_cashflows.index)
|
230 |
-
if common_idx.empty:
|
231 |
-
r2 = float('nan') # Cannot compute R2 if no common cashflow periods
|
232 |
-
print(f"[{pd.Timestamp.now()}] Warning: No common indices for R-squared calculation.")
|
233 |
-
else:
|
234 |
-
r2 = r2_score(seriatim_cashflows.loc[common_idx], proxy_cashflows.loc[common_idx])
|
235 |
|
236 |
-
|
237 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
238 |
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
print(f"[{pd.Timestamp.now()}] Accuracy metrics calculated.")
|
244 |
-
except Exception as e:
|
245 |
-
error_msg = f"Error calculating accuracy metrics: {e}"
|
246 |
-
print(f"[{pd.Timestamp.now()}] {error_msg}")
|
247 |
-
return (None, None, None, error_msg)
|
248 |
|
249 |
-
print(f"[{pd.Timestamp.now()}] --- cluster_analysis completed successfully ---")
|
250 |
return csv_data, cashflow_plot, pv_plot, metrics_text
|
251 |
|
252 |
-
# --- Gradio Interface ---
|
253 |
with gr.Blocks() as demo:
|
254 |
-
gr.Markdown("# Actuarial Model Point Selection")
|
255 |
-
gr.Markdown("""
|
256 |
-
This application performs cluster analysis on policy data to select representative model points.
|
257 |
-
It then aggregates cashflows and present values based on these model points and compares them
|
258 |
-
to the seriatim (full portfolio) results, providing accuracy metrics and visualizations.
|
259 |
-
|
260 |
-
**Instructions:**
|
261 |
-
1. **Upload Policy Data (Excel or CSV file):** Ensure it contains columns named exactly `IssueAge`, `PolicyTerm`, `SumAssured`, and `Duration`. **Crucially, double-check for leading/trailing spaces in column names.** The first column can be a unique policy identifier (though not explicitly used for clustering, it helps with index matching).
|
262 |
-
2. **Upload Cashflow Data (Excel or CSV file):** The **first column** of this file must be a unique policy identifier (e.g., `policy_id`), and this column will be used as the DataFrame index. The remaining columns should represent cashflow periods (e.g., `CF_Year_1`, `CF_Year_2`).
|
263 |
-
3. **Upload Present Value Data (Excel or CSV file):** The **first column** of this file must also be a unique policy identifier, matching the policy data's identifiers. The second column should contain the present value for each policy.
|
264 |
-
4. Adjust the 'Number of Model Points' using the slider.
|
265 |
-
5. Click 'Run Clustering'.
|
266 |
-
""")
|
267 |
|
268 |
with gr.Row():
|
269 |
with gr.Column():
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
run_btn = gr.Button("Run Clustering", variant="primary")
|
276 |
|
277 |
with gr.Column():
|
278 |
-
output_csv = gr.Textbox(label="Model Points CSV Output
|
279 |
-
cashflow_img = gr.Image(label="Aggregated Cashflows Comparison",
|
280 |
-
pv_img = gr.Image(label="Aggregated Present Values Comparison",
|
281 |
-
metrics_box = gr.Textbox(label="Accuracy Metrics
|
282 |
|
283 |
run_btn.click(
|
284 |
cluster_analysis,
|
@@ -286,6 +145,5 @@ with gr.Blocks() as demo:
|
|
286 |
outputs=[output_csv, cashflow_img, pv_img, metrics_box]
|
287 |
)
|
288 |
|
289 |
-
|
290 |
-
|
291 |
-
demo.launch(debug=True)
|
|
|
5 |
from sklearn.metrics import r2_score, pairwise_distances_argmin_min
|
6 |
import matplotlib.pyplot as plt
|
7 |
import io
|
|
|
8 |
|
9 |
def _figure_to_image(fig):
    """Render a matplotlib figure to an RGBA ``uint8`` array and close it.

    The figure is written to an in-memory PNG and decoded again, so the
    result does not depend on which matplotlib backend is active.  The figure
    is closed even if layout or saving fails, so failed runs do not leak
    figures.
    """
    buf = io.BytesIO()
    try:
        fig.tight_layout()
        fig.savefig(buf, format='png')
    finally:
        plt.close(fig)
    buf.seek(0)
    # plt.imread yields floats in [0, 1] for PNG input; scale to 8-bit so the
    # image component receives an ordinary uint8 pixel array.
    return (plt.imread(buf, format='png') * 255).round().astype('uint8')


def cluster_analysis(policy_file, cashflow_file, pv_file, num_clusters):
    """Select model points by k-means clustering and compare them to seriatim results.

    Policies are clustered on standardized ``IssueAge``, ``PolicyTerm``,
    ``SumAssured`` and ``Duration``.  The policy closest to each cluster
    centre becomes a model point, weighted by the number of policies in its
    cluster.  Weighted model-point cashflows and present values are then
    compared against the totals over all policies.

    Args:
        policy_file: Uploaded file object whose ``.name`` is the path of the
            policy CSV (first column is used as the index).
        cashflow_file: Uploaded file object for the per-policy cashflow CSV,
            indexed the same way.
        pv_file: Uploaded file object for the per-policy present value CSV,
            indexed the same way.
        num_clusters: Number of model points to select; coerced with ``int``.

    Returns:
        A 4-tuple ``(csv_data, cashflow_plot, pv_plot, metrics_text)``.
        ``csv_data`` is the model points as CSV text, the two plots are RGBA
        ``uint8`` arrays, and ``metrics_text`` is a short accuracy report.
        On a handled error the tuple is ``(None, None, None, message)``.
    """
    # Nothing uploaded yet: report that directly instead of letting the
    # attribute access on None surface as a confusing "error reading CSV".
    if policy_file is None or cashflow_file is None or pv_file is None:
        return (None, None, None, "Error: please upload the policy, cashflow and present value CSV files before running.")

    # Basic checks and reads
    try:
        # .name is the path of the temporary file created for the upload
        policy_df = pd.read_csv(policy_file.name, index_col=0)
        cashflow_df = pd.read_csv(cashflow_file.name, index_col=0)
        pv_df = pd.read_csv(pv_file.name, index_col=0)
    except Exception as e:
        return (None, None, None, f"Error reading CSV files: {e}. Ensure files are CSVs and the first column is the index (e.g., Policy ID).")

    # Policy attributes used for clustering
    required_cols = ['IssueAge', 'PolicyTerm', 'SumAssured', 'Duration']
    missing_cols = [col for col in required_cols if col not in policy_df.columns]
    if missing_cols:
        return (None, None, None, f"Policy data missing required columns: {missing_cols}. Please ensure your policy CSV has these columns.")

    X = policy_df[required_cols].fillna(0)  # Simple imputation

    # A constant column cannot be standardized (division by zero), so stop early.
    X_std = X.std()
    if (X_std == 0).any():
        zero_std_cols = X_std[X_std == 0].index.tolist()
        return (None, None, None, f"Error: Columns {zero_std_cols} have zero standard deviation after fillna(0). Cannot scale these columns. Please check your data.")

    X_scaled = (X - X.mean()) / X_std

    # Cluster
    try:
        kmeans = KMeans(n_clusters=int(num_clusters), random_state=42, n_init=10)
        kmeans.fit(X_scaled)
        policy_df['Cluster'] = kmeans.labels_
    except Exception as e:
        return (None, None, None, f"Clustering error: {e}")

    # Model points are the policies closest to each cluster centre
    closest, _ = pairwise_distances_argmin_min(kmeans.cluster_centers_, X_scaled)
    model_points = policy_df.iloc[closest].copy()

    # Each model point stands in for every policy in its cluster
    counts = policy_df['Cluster'].value_counts()
    model_points['Weight'] = model_points['Cluster'].map(counts)

    # Every model point must have a row in the cashflow and PV tables
    if not model_points.index.isin(cashflow_df.index).all():
        return (None, None, None, "Error: Model point indices not found in cashflow data. Ensure Policy IDs match.")
    if not model_points.index.isin(pv_df.index).all():
        return (None, None, None, "Error: Model point indices not found in PV data. Ensure Policy IDs match.")

    # CSV text of the model points (index kept so the policy ID is included)
    csv_buffer = io.StringIO()
    model_points.to_csv(csv_buffer)
    csv_data = csv_buffer.getvalue()

    # Weights are applied positionally, so they must be plain numbers
    model_points['Weight'] = pd.to_numeric(model_points['Weight'], errors='coerce').fillna(1)
    weights = model_points['Weight'].values

    # Aggregate cashflows: weighted model points vs. the full portfolio
    proxy_cashflows = cashflow_df.loc[model_points.index].multiply(weights, axis=0).sum()
    seriatim_cashflows = cashflow_df.sum()

    # Plot aggregated cashflows
    fig, ax = plt.subplots(figsize=(8, 4))
    seriatim_cashflows.plot(ax=ax, label='Seriatim Cashflows')
    proxy_cashflows.plot(ax=ax, label='Proxy Cashflows', linestyle='--')
    ax.set_title('Aggregated Cashflows Comparison')
    ax.legend()
    ax.grid(True)
    # Returned as a pixel array rather than raw PNG bytes: the result feeds an
    # image output component, which does not accept a bytes value.
    cashflow_plot = _figure_to_image(fig)

    # Aggregate present values.  Summing over every column gives the single
    # column's total when there is one and the grand total when there are more.
    proxy_pv = pv_df.loc[model_points.index].multiply(weights, axis=0).sum().sum()
    seriatim_pv = pv_df.sum().sum()

    # Present value comparison plot (bar)
    fig2, ax2 = plt.subplots(figsize=(5, 4))
    ax2.bar(['Seriatim PV', 'Proxy PV'], [seriatim_pv, proxy_pv], color=['blue', 'orange'])
    ax2.set_title('Aggregated Present Values')
    ax2.grid(axis='y')
    pv_plot = _figure_to_image(fig2)

    # Accuracy metrics
    common_idx = seriatim_cashflows.index.intersection(proxy_cashflows.index)
    if common_idx.empty:
        r2 = float('nan')  # no shared cashflow periods to compare
    else:
        r2 = r2_score(seriatim_cashflows.loc[common_idx], proxy_cashflows.loc[common_idx])

    # Divide by the magnitude so the reported error stays non-negative even
    # when the portfolio present value is negative.
    if seriatim_pv != 0:
        pv_error = abs(proxy_pv - seriatim_pv) / abs(seriatim_pv) * 100
    else:
        pv_error = float('inf')

    metrics_text = (
        f"R-squared for aggregated cashflows: {r2:.4f}\n"
        f"Absolute percentage error in present value: {pv_error:.4f}%"
    )

    return csv_data, cashflow_plot, pv_plot, metrics_text
|
124 |
|
|
|
125 |
with gr.Blocks() as demo:
|
126 |
+
gr.Markdown("# Actuarial Model Point Selection (CSV Upload)")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
|
128 |
with gr.Row():
|
129 |
with gr.Column():
|
130 |
+
policy_input = gr.File(label="Upload Policy Data (CSV with PolicyID as first column)")
|
131 |
+
cashflow_input = gr.File(label="Upload Cashflow Data (CSV with PolicyID as first column)")
|
132 |
+
pv_input = gr.File(label="Upload Present Value Data (CSV with PolicyID as first column)")
|
133 |
+
clusters_input = gr.Slider(minimum=2, maximum=100, step=1, value=10, label="Number of Model Points")
|
134 |
+
run_btn = gr.Button("Run Clustering")
|
|
|
135 |
|
136 |
with gr.Column():
|
137 |
+
output_csv = gr.Textbox(label="Model Points CSV Output", lines=10, interactive=False)
|
138 |
+
cashflow_img = gr.Image(label="Aggregated Cashflows Comparison", type="pil") # Using PIL for better compatibility
|
139 |
+
pv_img = gr.Image(label="Aggregated Present Values Comparison", type="pil")
|
140 |
+
metrics_box = gr.Textbox(label="Accuracy Metrics", lines=4, interactive=False)
|
141 |
|
142 |
run_btn.click(
|
143 |
cluster_analysis,
|
|
|
145 |
outputs=[output_csv, cashflow_img, pv_img, metrics_box]
|
146 |
)
|
147 |
|
148 |
+
# Start the Gradio app only when this file is executed directly, so importing
# the module does not launch a server as a side effect.
if __name__ == "__main__":
    demo.launch(debug=True)
|
|