alidenewade committed on
Commit 74d9c0e · verified · 1 Parent(s): ee9fc3c

Update app.py

Files changed (1)
  1. app.py +337 -127

app.py CHANGED
@@ -1,149 +1,359 @@
  import gradio as gr
- import pandas as pd
  import numpy as np
  from sklearn.cluster import KMeans
- from sklearn.metrics import r2_score, pairwise_distances_argmin_min
  import matplotlib.pyplot as plt
  import io

- def cluster_analysis(policy_file, cashflow_file, pv_file, num_clusters):
-     # Basic checks and reads
-     try:
-         # Use policy_file.name, which is the path to the temporary file Gradio creates
-         policy_df = pd.read_csv(policy_file.name, index_col=0)
-         cashflow_df = pd.read_csv(cashflow_file.name, index_col=0)
-         pv_df = pd.read_csv(pv_file.name, index_col=0)
-     except Exception as e:
-         return (None, None, None, f"Error reading CSV files: {e}. Ensure files are CSVs and the first column is the index (e.g., Policy ID).")
-
-     # Use policy attributes for clustering
-     # Ensure these column names match your policy data CSV
-     required_cols = ['IssueAge', 'PolicyTerm', 'SumAssured', 'Duration']
-     if not all(col in policy_df.columns for col in required_cols):
-         missing_cols = [col for col in required_cols if col not in policy_df.columns]
-         return (None, None, None, f"Policy data missing required columns: {missing_cols}. Please ensure your policy CSV has these columns.")
-
-     X = policy_df[required_cols].fillna(0)  # Simple imputation
-
-     # Handle columns with zero standard deviation (e.g., a column whose values are all identical after fillna)
-     X_std = X.std()
-     if (X_std == 0).any():
-         zero_std_cols = X_std[X_std == 0].index.tolist()
-         return (None, None, None, f"Error: Columns {zero_std_cols} have zero standard deviation after fillna(0). Cannot scale these columns. Please check your data.")

-     X_scaled = (X - X.mean()) / X_std
-
-     # Cluster
-     try:
-         kmeans = KMeans(n_clusters=int(num_clusters), random_state=42, n_init=10)
-         kmeans.fit(X_scaled)
-         policy_df['Cluster'] = kmeans.labels_
-     except Exception as e:
-         return (None, None, None, f"Clustering error: {e}")
-
-     # Select model points as the policies closest to the cluster centers
-     centers = kmeans.cluster_centers_
-     closest, _ = pairwise_distances_argmin_min(centers, X_scaled)
-     model_points = policy_df.iloc[closest].copy()
-
-     # Calculate weights (policy count per cluster)
-     counts = policy_df['Cluster'].value_counts()
-     model_points['Weight'] = model_points['Cluster'].map(counts)
-
-     # Ensure model_points.index values are valid for cashflow_df and pv_df
-     if not model_points.index.isin(cashflow_df.index).all():
-         return (None, None, None, "Error: Model point indices not found in cashflow data. Ensure Policy IDs match.")
-     if not model_points.index.isin(pv_df.index).all():
-         return (None, None, None, "Error: Model point indices not found in PV data. Ensure Policy IDs match.")

-     # Create CSV for download
-     csv_buffer = io.StringIO()
-     model_points.to_csv(csv_buffer)  # index=True by default, which is good if the index is the Policy ID
-     csv_data = csv_buffer.getvalue()
-
-     # Aggregate cashflows weighted by cluster counts
-     # Ensure model_points['Weight'] is numeric for multiplication
-     model_points['Weight'] = pd.to_numeric(model_points['Weight'], errors='coerce').fillna(1)
-
-     proxy_cashflows_df = cashflow_df.loc[model_points.index]
-     proxy_cashflows = proxy_cashflows_df.multiply(model_points['Weight'].values, axis=0).sum()
-     seriatim_cashflows = cashflow_df.sum()
-
-     # Plot aggregated cashflows
-     fig, ax = plt.subplots(figsize=(8, 4))
-     seriatim_cashflows.plot(ax=ax, label='Seriatim Cashflows')
-     proxy_cashflows.plot(ax=ax, label='Proxy Cashflows', linestyle='--')
-     ax.set_title('Aggregated Cashflows Comparison')
-     ax.legend()
-     ax.grid(True)
-     plt.tight_layout()
      buf = io.BytesIO()
-     plt.savefig(buf, format='png')
-     plt.close(fig)
      buf.seek(0)
-     cashflow_plot = buf.read()
-
-     # Aggregate present values, weighted
-     proxy_pv_df = pv_df.loc[model_points.index]
-     # Assume pv_df has a single column of PVs; if there are several, sum across all columns
-     if proxy_pv_df.shape[1] > 1:
-         proxy_pv = proxy_pv_df.multiply(model_points['Weight'].values, axis=0).sum().sum()
-         seriatim_pv = pv_df.sum().sum()
-     else:
-         proxy_pv = proxy_pv_df.multiply(model_points['Weight'].values, axis=0).sum().iloc[0]
-         seriatim_pv = pv_df.sum().iloc[0]
-
-     # Present value comparison plot (bar)
-     fig2, ax2 = plt.subplots(figsize=(5, 4))
-     ax2.bar(['Seriatim PV', 'Proxy PV'], [seriatim_pv, proxy_pv], color=['blue', 'orange'])
-     ax2.set_title('Aggregated Present Values')
-     ax2.grid(axis='y')
      plt.tight_layout()
-     buf2 = io.BytesIO()
-     plt.savefig(buf2, format='png')
-     plt.close(fig2)
-     buf2.seek(0)
-     pv_plot = buf2.read()
-
-     # Accuracy metrics
-     common_idx = seriatim_cashflows.index.intersection(proxy_cashflows.index)
-     if not common_idx.empty:
-         r2 = r2_score(seriatim_cashflows.loc[common_idx], proxy_cashflows.loc[common_idx])
-     else:
-         r2 = float('nan')  # Or handle as an error
-
-     pv_error = abs(proxy_pv - seriatim_pv) / seriatim_pv * 100 if seriatim_pv != 0 else float('inf')
-
-     metrics_text = (
-         f"R-squared for aggregated cashflows: {r2:.4f}\n"
-         f"Absolute percentage error in present value: {pv_error:.4f}%"
-     )
-
-     return csv_data, cashflow_plot, pv_plot, metrics_text
-
- with gr.Blocks() as demo:
-     gr.Markdown("# Actuarial Model Point Selection (CSV Upload)")
-
-     with gr.Row():
-         with gr.Column():
-             policy_input = gr.File(label="Upload Policy Data (CSV with PolicyID as first column)")
-             cashflow_input = gr.File(label="Upload Cashflow Data (CSV with PolicyID as first column)")
-             pv_input = gr.File(label="Upload Present Value Data (CSV with PolicyID as first column)")
-             clusters_input = gr.Slider(minimum=2, maximum=100, step=1, value=10, label="Number of Model Points")
-             run_btn = gr.Button("Run Clustering")
-
-         with gr.Column():
-             output_csv = gr.Textbox(label="Model Points CSV Output", lines=10, interactive=False)
-             cashflow_img = gr.Image(label="Aggregated Cashflows Comparison", type="pil")  # Using PIL for better compatibility
-             pv_img = gr.Image(label="Aggregated Present Values Comparison", type="pil")
-             metrics_box = gr.Textbox(label="Accuracy Metrics", lines=4, interactive=False)
-
-     run_btn.click(
-         cluster_analysis,
-         inputs=[policy_input, cashflow_input, pv_input, clusters_input],
-         outputs=[output_csv, cashflow_img, pv_img, metrics_box]
-     )
-
- if __name__ == '__main__':
-     demo.launch(debug=True)

  import gradio as gr
  import numpy as np
+ import pandas as pd
  from sklearn.cluster import KMeans
+ from sklearn.metrics import pairwise_distances_argmin_min, r2_score
  import matplotlib.pyplot as plt
+ import matplotlib.cm
  import io
+ import base64
+ from PIL import Image

+ class Clusters:
+     def __init__(self, loc_vars):
+         self.kmeans = kmeans = KMeans(n_clusters=1000, random_state=0, n_init=10).fit(np.ascontiguousarray(loc_vars))
+         closest, _ = pairwise_distances_argmin_min(kmeans.cluster_centers_, np.ascontiguousarray(loc_vars))
+
+         rep_ids = pd.Series(data=(closest + 1))  # 0-based to 1-based indexes
+         rep_ids.name = 'policy_id'
+         rep_ids.index.name = 'cluster_id'
+         self.rep_ids = rep_ids
+
+         self.policy_count = self.agg_by_cluster(pd.DataFrame({'policy_count': [1] * len(loc_vars)}))['policy_count']
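+         # Each cluster's representative later stands in for all policies in its
+         # cluster, weighted by policy_count. Note: mapping `closest + 1` to
+         # policy_id assumes the source tables are indexed by policy_id running
+         # 1..N in the same row order as loc_vars.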
 
+     def agg_by_cluster(self, df, agg=None):
+         """Aggregate columns by cluster"""
+         temp = df.copy()
+         temp['cluster_id'] = self.kmeans.labels_
+         temp = temp.set_index('cluster_id')
+         agg = {c: (agg[c] if c in agg else 'sum') for c in temp.columns} if agg else "sum"
+         return temp.groupby(temp.index).agg(agg)

+     def extract_reps(self, df):
+         """Extract the rows of representative policies"""
+         temp = pd.merge(self.rep_ids, df.reset_index(), how='left', on='policy_id')
+         temp.index.name = 'cluster_id'
+         return temp.drop('policy_id', axis=1)

+     def extract_and_scale_reps(self, df, agg=None):
+         """Extract and scale the rows of representative policies"""
+         if agg:
+             cols = df.columns
+             mult = pd.DataFrame({c: (self.policy_count if (c not in agg or agg[c] == 'sum') else 1) for c in cols})
+             return self.extract_reps(df).mul(mult)
+         else:
+             return self.extract_reps(df).mul(self.policy_count, axis=0)
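+     # Scaling sketch: a 'sum' column of a representative is multiplied by its
+     # cluster's policy count, while a 'mean' column keeps a multiplier of 1.
+     # For example, a cluster of 40 policies whose representative has an annual
+     # cashflow of 100 contributes an estimated 4000 to the portfolio total.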
 
+     def compare(self, df, agg=None):
+         """Returns a multi-indexed DataFrame comparing actual and estimate"""
+         source = self.agg_by_cluster(df, agg)
+         target = self.extract_and_scale_reps(df, agg)
+         return pd.DataFrame({'actual': source.stack(), 'estimate': target.stack()})

+     def compare_total(self, df, agg=None):
+         """Aggregate df by columns and compare actual against estimate"""
+         if agg:
+             cols = df.columns
+             op = {c: (agg[c] if c in agg else 'sum') for c in df.columns}
+             actual = df.agg(op)
+             estimate = self.extract_and_scale_reps(df, agg=op)
+
+             op = {k: ((lambda s: s.dot(self.policy_count) / self.policy_count.sum()) if v == 'mean' else v) for k, v in op.items()}
+             estimate = estimate.agg(op)
+         else:
+             actual = df.sum()
+             estimate = self.extract_and_scale_reps(df).sum()
+
+         return pd.DataFrame({'actual': actual, 'estimate': estimate, 'error': estimate / actual - 1})
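+     # The 'error' column is the relative error, estimate / actual - 1. For 'mean'
+     # columns the estimate is re-aggregated as a policy-count-weighted mean via
+     # s.dot(policy_count) / policy_count.sum(), since those rows were deliberately
+     # left unscaled above.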
 
+ def create_plot(plot_func, *args, **kwargs):
+     """Helper function to create plots and return as image"""
+     plt.figure(figsize=(10, 6))
+     plot_func(*args, **kwargs)
+
+     # Save plot to bytes
      buf = io.BytesIO()
+     plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
      buf.seek(0)
+     plt.close()
+
+     return Image.open(buf)

+ def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
+     """Create cashflow comparison plots"""
+     fig, axes = plt.subplots(2, 2, figsize=(15, 10))
+     axes = axes.flatten()
+
+     for i, (df, title) in enumerate(zip(cfs_list, titles)):
+         if i < len(axes):
+             comparison = cluster_obj.compare_total(df)
+             comparison[['actual', 'estimate']].plot(ax=axes[i], grid=True, title=title)
+
      plt.tight_layout()
+     buf = io.BytesIO()
+     plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
+     buf.seek(0)
+     plt.close()
+
+     return Image.open(buf)

+ def plot_scatter_comparison(df, title):
+     """Create scatter plot comparison"""
+     plt.figure(figsize=(12, 8))
+
+     colors = matplotlib.cm.rainbow(np.linspace(0, 1, len(df.index.levels[1])))
+
+     for y, c in zip(df.index.levels[1], colors):
+         plt.scatter(df.xs(y, level=1)['actual'], df.xs(y, level=1)['estimate'],
+                     color=c, s=9, alpha=0.6)
+
+     plt.xlabel('Actual')
+     plt.ylabel('Estimate')
+     plt.title(title)
+     plt.grid(True)
+
+     # Draw identity line
+     lims = [
+         np.min([plt.xlim(), plt.ylim()]),
+         np.max([plt.xlim(), plt.ylim()]),
+     ]
+     plt.plot(lims, lims, 'r-', linewidth=0.5)
+     plt.xlim(lims)
+     plt.ylim(lims)
+
+     buf = io.BytesIO()
+     plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
+     buf.seek(0)
+     plt.close()
+
+     return Image.open(buf)
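+ # Points lying on the red y = x identity line are cluster/period cells whose
+ # estimate reproduces the actual value exactly.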
 
+ def process_files(cashflow_base, cashflow_lapse, cashflow_mort, policy_data, pv_base, pv_lapse, pv_mort):
+     """Main processing function"""
+     try:
+         # Read uploaded files
+         cfs = pd.read_excel(cashflow_base.name, index_col=0)
+         cfs_lapse50 = pd.read_excel(cashflow_lapse.name, index_col=0)
+         cfs_mort15 = pd.read_excel(cashflow_mort.name, index_col=0)
+
+         pol_data = pd.read_excel(policy_data.name, index_col=0)
+         if pol_data.shape[1] > 4:
+             pol_data = pol_data[['age_at_entry', 'policy_term', 'sum_assured', 'duration_mth']]
+
+         pvs = pd.read_excel(pv_base.name, index_col=0)
+         pvs_lapse50 = pd.read_excel(pv_lapse.name, index_col=0)
+         pvs_mort15 = pd.read_excel(pv_mort.name, index_col=0)
+
+         cfs_list = [cfs, cfs_lapse50, cfs_mort15]
+         pvs_list = [pvs, pvs_lapse50, pvs_mort15]
+         scen_titles = ['Base', 'Lapse+50%', 'Mort+15%']
+
+         results = {}
+
+         # 1. Cashflow Calibration
+         cluster_cfs = Clusters(cfs)
+
+         # Cashflow comparison tables
+         results['cf_base_table'] = cluster_cfs.compare_total(cfs)
+         results['cf_lapse_table'] = cluster_cfs.compare_total(cfs_lapse50)
+         results['cf_mort_table'] = cluster_cfs.compare_total(cfs_mort15)
+
+         # Policy attributes analysis
+         mean_attrs = {'age_at_entry': 'mean', 'policy_term': 'mean', 'duration_mth': 'mean'}
+         results['cf_policy_attrs'] = cluster_cfs.compare_total(pol_data, agg=mean_attrs)
+
+         # Present value analysis
+         results['cf_pv_base'] = cluster_cfs.compare_total(pvs)
+         results['cf_pv_lapse'] = cluster_cfs.compare_total(pvs_lapse50)
+         results['cf_pv_mort'] = cluster_cfs.compare_total(pvs_mort15)
+
+         # Create plots for cashflow calibration
+         results['cf_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_cfs, scen_titles)
+         results['cf_scatter_base'] = plot_scatter_comparison(cluster_cfs.compare(cfs), 'Cashflow Calibration - Base Scenario')
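+         # Here the calibration variables are the raw annual cashflows themselves,
+         # so k-means groups policies with similar base-scenario cashflow patterns.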
+
+         # 2. Policy Attribute Calibration
+         loc_vars = (pol_data - pol_data.min()) / (pol_data.max() - pol_data.min())
+         cluster_attrs = Clusters(loc_vars)
+
+         results['attr_cf_base'] = cluster_attrs.compare_total(cfs)
+         results['attr_policy_attrs'] = cluster_attrs.compare_total(pol_data, agg=mean_attrs)
+         results['attr_pv_base'] = cluster_attrs.compare_total(pvs)
+         results['attr_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_attrs, scen_titles)
+         results['attr_scatter_base'] = plot_scatter_comparison(cluster_attrs.compare(cfs), 'Policy Attribute Calibration - Base Scenario')
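+         # The min-max scaling above maps each attribute to [0, 1] so that no single
+         # attribute (e.g. sum_assured) dominates the Euclidean distance used by k-means.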
+
+         # 3. Present Value Calibration
+         cluster_pvs = Clusters(pvs)
+
+         results['pv_cf_base'] = cluster_pvs.compare_total(cfs)
+         results['pv_policy_attrs'] = cluster_pvs.compare_total(pol_data, agg=mean_attrs)
+         results['pv_pv_base'] = cluster_pvs.compare_total(pvs)
+         results['pv_pv_lapse'] = cluster_pvs.compare_total(pvs_lapse50)
+         results['pv_pv_mort'] = cluster_pvs.compare_total(pvs_mort15)
+         results['pv_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_pvs, scen_titles)
+         results['pv_scatter_base'] = plot_scatter_comparison(cluster_pvs.compare(pvs), 'Present Value Calibration - Base Scenario')
+
+         # Summary comparison plot
+         fig, ax = plt.subplots(figsize=(12, 8))
+         # Use .abs().mean() so signed errors do not cancel, matching the
+         # 'Mean Absolute Error' axis label below
+         comparison_data = {
+             'Cashflow Calibration': [
+                 cluster_cfs.compare_total(cfs)['error'].abs().mean(),
+                 cluster_cfs.compare_total(pvs)['error'].abs().mean()
+             ],
+             'Policy Attribute Calibration': [
+                 cluster_attrs.compare_total(cfs)['error'].abs().mean(),
+                 cluster_attrs.compare_total(pvs)['error'].abs().mean()
+             ],
+             'Present Value Calibration': [
+                 cluster_pvs.compare_total(cfs)['error'].abs().mean(),
+                 cluster_pvs.compare_total(pvs)['error'].abs().mean()
+             ]
+         }
+
+         x = np.arange(2)
+         width = 0.25
+
+         ax.bar(x - width, comparison_data['Cashflow Calibration'], width, label='Cashflow Calibration')
+         ax.bar(x, comparison_data['Policy Attribute Calibration'], width, label='Policy Attribute Calibration')
+         ax.bar(x + width, comparison_data['Present Value Calibration'], width, label='Present Value Calibration')
+
+         ax.set_ylabel('Mean Absolute Error')
+         ax.set_title('Calibration Method Comparison')
+         ax.set_xticks(x)
+         ax.set_xticklabels(['Cashflows', 'Present Values'])
+         ax.legend()
+         ax.grid(True, alpha=0.3)
+
+         buf = io.BytesIO()
+         plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
+         buf.seek(0)
+         plt.close()
+         results['summary_plot'] = Image.open(buf)
+
+         return results
+
+     except Exception as e:
+         return {"error": f"Error processing files: {str(e)}"}

+ def create_interface():
+     with gr.Blocks(title="Cluster Model Points Analysis", theme=gr.themes.Soft()) as demo:
+         gr.Markdown("""
+         # Cluster Model Points Analysis
+
+         This application applies cluster analysis to model point selection for insurance portfolios.
+         Upload your Excel files to analyze cashflows, policy attributes, and present values using different calibration methods.
+
+         **Required Files:**
+         - 3 Cashflow files (Base, Lapse stress, Mortality stress scenarios)
+         - 1 Policy data file
+         - 3 Present value files (Base, Lapse stress, Mortality stress scenarios)
+         """)
+
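+         # Expected layout (every file is read with index_col=0): the first column
+         # must be the policy_id; cashflow files have one column per projection period.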
+         with gr.Row():
+             with gr.Column():
+                 gr.Markdown("### Upload Files")
+                 cashflow_base = gr.File(label="Cashflows - Base Scenario", file_types=[".xlsx"])
+                 cashflow_lapse = gr.File(label="Cashflows - Lapse Stress (+50%)", file_types=[".xlsx"])
+                 cashflow_mort = gr.File(label="Cashflows - Mortality Stress (+15%)", file_types=[".xlsx"])
+                 policy_data = gr.File(label="Policy Data", file_types=[".xlsx"])
+                 pv_base = gr.File(label="Present Values - Base Scenario", file_types=[".xlsx"])
+                 pv_lapse = gr.File(label="Present Values - Lapse Stress", file_types=[".xlsx"])
+                 pv_mort = gr.File(label="Present Values - Mortality Stress", file_types=[".xlsx"])
+
+         analyze_btn = gr.Button("Analyze", variant="primary", size="lg")
+
+         with gr.Tabs():
+             with gr.TabItem("Summary"):
+                 summary_plot = gr.Image(label="Calibration Methods Comparison")
+
+             with gr.TabItem("Cashflow Calibration"):
+                 gr.Markdown("### Results using Annual Cashflows as Calibration Variables")
+
+                 with gr.Row():
+                     cf_base_table = gr.Dataframe(label="Base Scenario Comparison")
+                     cf_policy_attrs = gr.Dataframe(label="Policy Attributes Comparison")
+
+                 cf_cashflow_plot = gr.Image(label="Cashflow Comparisons Across Scenarios")
+                 cf_scatter_base = gr.Image(label="Scatter Plot - Base Scenario")
+
+                 with gr.Row():
+                     cf_pv_base = gr.Dataframe(label="Present Values - Base")
+                     cf_pv_lapse = gr.Dataframe(label="Present Values - Lapse Stress")
+                     cf_pv_mort = gr.Dataframe(label="Present Values - Mortality Stress")
+
+             with gr.TabItem("Policy Attribute Calibration"):
+                 gr.Markdown("### Results using Policy Attributes as Calibration Variables")
+
+                 with gr.Row():
+                     attr_cf_base = gr.Dataframe(label="Cashflows - Base Scenario")
+                     attr_policy_attrs = gr.Dataframe(label="Policy Attributes Comparison")
+
+                 attr_cashflow_plot = gr.Image(label="Cashflow Comparisons Across Scenarios")
+                 attr_scatter_base = gr.Image(label="Scatter Plot - Base Scenario")
+                 attr_pv_base = gr.Dataframe(label="Present Values - Base Scenario")
+
+             with gr.TabItem("Present Value Calibration"):
+                 gr.Markdown("### Results using Present Values as Calibration Variables")
+
+                 with gr.Row():
+                     pv_cf_base = gr.Dataframe(label="Cashflows - Base Scenario")
+                     pv_policy_attrs = gr.Dataframe(label="Policy Attributes Comparison")
+
+                 pv_cashflow_plot = gr.Image(label="Cashflow Comparisons Across Scenarios")
+                 pv_scatter_base = gr.Image(label="Scatter Plot - Base Scenario")
+
+                 with gr.Row():
+                     pv_pv_base = gr.Dataframe(label="Present Values - Base")
+                     pv_pv_lapse = gr.Dataframe(label="Present Values - Lapse Stress")
+                     pv_pv_mort = gr.Dataframe(label="Present Values - Mortality Stress")
+
+         def update_interface(cashflow_base, cashflow_lapse, cashflow_mort, policy_data, pv_base, pv_lapse, pv_mort):
+             if not all([cashflow_base, cashflow_lapse, cashflow_mort, policy_data, pv_base, pv_lapse, pv_mort]):
+                 return [None] * 20  # one None per output component
+
+             results = process_files(cashflow_base, cashflow_lapse, cashflow_mort, policy_data, pv_base, pv_lapse, pv_mort)
+
+             if "error" in results:
+                 gr.Warning(results["error"])
+                 return [None] * 20
+
+             return [
+                 results.get('summary_plot'),
+                 results.get('cf_base_table'),
+                 results.get('cf_policy_attrs'),
+                 results.get('cf_cashflow_plot'),
+                 results.get('cf_scatter_base'),
+                 results.get('cf_pv_base'),
+                 results.get('cf_pv_lapse'),
+                 results.get('cf_pv_mort'),
+                 results.get('attr_cf_base'),
+                 results.get('attr_policy_attrs'),
+                 results.get('attr_cashflow_plot'),
+                 results.get('attr_scatter_base'),
+                 results.get('attr_pv_base'),
+                 results.get('pv_cf_base'),
+                 results.get('pv_policy_attrs'),
+                 results.get('pv_cashflow_plot'),
+                 results.get('pv_scatter_base'),
+                 results.get('pv_pv_base'),
+                 results.get('pv_pv_lapse'),
+                 results.get('pv_pv_mort')
+             ]
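+         # The order of this return list must match the `outputs` list wired to
+         # analyze_btn.click below (20 components in total).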
+
+         analyze_btn.click(
+             update_interface,
+             inputs=[cashflow_base, cashflow_lapse, cashflow_mort, policy_data, pv_base, pv_lapse, pv_mort],
+             outputs=[
+                 summary_plot,
+                 cf_base_table, cf_policy_attrs, cf_cashflow_plot, cf_scatter_base,
+                 cf_pv_base, cf_pv_lapse, cf_pv_mort,
+                 attr_cf_base, attr_policy_attrs, attr_cashflow_plot, attr_scatter_base, attr_pv_base,
+                 pv_cf_base, pv_policy_attrs, pv_cashflow_plot, pv_scatter_base,
+                 pv_pv_base, pv_pv_lapse, pv_pv_mort
+             ]
+         )
+
+     return demo
+
+ if __name__ == "__main__":
+     demo = create_interface()
+     demo.launch()
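+ # Note: pd.read_excel needs an Excel engine such as openpyxl installed
+ # alongside gradio, pandas, scikit-learn, matplotlib and Pillow to read .xlsx files.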