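"""Gradio app for cluster-based model point selection in insurance portfolios.

Builds 1,000-cluster k-means representatives from three alternative calibration variable
sets (annual cashflows, policy attributes, present values) and compares how well each set
reproduces seriatim cashflows and present values under base, lapse-stress and
mortality-stress scenarios.
"""
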
import gradio as gr
import numpy as np
import pandas as pd
from sklearn.cluster import KMeans
from sklearn.metrics import pairwise_distances_argmin_min
import matplotlib.pyplot as plt
import matplotlib.cm
import io
import os  # used to locate the example data files and create the data directory
from PIL import Image

# Define the paths for example data
EXAMPLE_DATA_DIR = "eg_data"
EXAMPLE_FILES = {
    "cashflow_base": os.path.join(EXAMPLE_DATA_DIR, "cashflows_seriatim_10K.xlsx"),
    "cashflow_lapse": os.path.join(EXAMPLE_DATA_DIR, "cashflows_seriatim_10K_lapse50.xlsx"),
    "cashflow_mort": os.path.join(EXAMPLE_DATA_DIR, "cashflows_seriatim_10K_mort15.xlsx"),
    "policy_data": os.path.join(EXAMPLE_DATA_DIR, "model_point_table.xlsx"),
    "pv_base": os.path.join(EXAMPLE_DATA_DIR, "pv_seriatim_10K.xlsx"),
    "pv_lapse": os.path.join(EXAMPLE_DATA_DIR, "pv_seriatim_10K_lapse50.xlsx"),
    "pv_mort": os.path.join(EXAMPLE_DATA_DIR, "pv_seriatim_10K_mort15.xlsx"),
}
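
# Assumed layout of the input workbooks (inferred from how they are read below, not verified
# against the data itself): every sheet is indexed by a 1-based policy_id (read via
# index_col=0); cashflow files hold one column per projection period; PV files hold one
# column per present-value component; and the policy data file contains at least
# 'age_at_entry', 'policy_term', 'sum_assured' and 'duration_mth'.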

class Clusters:
    def __init__(self, loc_vars):
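        # Fit k-means on the calibration variables: each of the 1,000 clusters becomes one
        # model point, represented by the in-force policy closest to its cluster centre.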
        self.kmeans = kmeans = KMeans(n_clusters=1000, random_state=0, n_init=10).fit(np.ascontiguousarray(loc_vars))
        closest, _ = pairwise_distances_argmin_min(kmeans.cluster_centers_, np.ascontiguousarray(loc_vars))
        
        rep_ids = pd.Series(data=(closest+1))  # 0-based to 1-based indexes
        rep_ids.name = 'policy_id'
        rep_ids.index.name = 'cluster_id'
        self.rep_ids = rep_ids
        
        self.policy_count = self.agg_by_cluster(pd.DataFrame({'policy_count': [1] * len(loc_vars)}))['policy_count']

    def agg_by_cluster(self, df, agg=None):
        """Aggregate columns by cluster"""
        temp = df.copy()
        temp['cluster_id'] = self.kmeans.labels_
        temp = temp.set_index('cluster_id')
        agg = {c: (agg[c] if agg and c in agg else 'sum') for c in temp.columns} if agg else "sum"
        return temp.groupby(temp.index).agg(agg)

    def extract_reps(self, df):
        """Extract the rows of representative policies"""
        temp = pd.merge(self.rep_ids, df.reset_index(), how='left', on='policy_id')
        temp.index.name = 'cluster_id'
        return temp.drop('policy_id', axis=1)

    def extract_and_scale_reps(self, df, agg=None):
        """Extract and scale the rows of representative policies"""
        if agg:
            cols = df.columns
            mult = pd.DataFrame({c: (self.policy_count if (c not in agg or agg[c] == 'sum') else 1) for c in cols})
            # Ensure mult has same index as extract_reps(df) for proper alignment
            extracted_df = self.extract_reps(df)
            mult.index = extracted_df.index 
            return extracted_df.mul(mult)
        else:
            return self.extract_reps(df).mul(self.policy_count, axis=0)

    def compare(self, df, agg=None):
        """Returns a multi-indexed Dataframe comparing actual and estimate"""
        source = self.agg_by_cluster(df, agg)
        target = self.extract_and_scale_reps(df, agg)
        return pd.DataFrame({'actual': source.stack(), 'estimate':target.stack()})

    def compare_total(self, df, agg=None):
        """Aggregate df by columns"""
        if agg:
            # Calculate actual values using specified aggregation
            actual_values = {}
            for col in df.columns:
                if agg.get(col, 'sum') == 'mean':
                    actual_values[col] = df[col].mean()
                else:  # sum
                    actual_values[col] = df[col].sum()
            actual = pd.Series(actual_values)
            
            # Calculate estimate values
            reps_unscaled = self.extract_reps(df)
            estimate_values = {}
            
            for col in df.columns:
                if agg.get(col, 'sum') == 'mean':
                    # Weighted average for mean columns
                    weighted_sum = (reps_unscaled[col] * self.policy_count).sum()
                    total_weight = self.policy_count.sum()
                    estimate_values[col] = weighted_sum / total_weight if total_weight > 0 else 0
                else:  # sum
                    estimate_values[col] = (reps_unscaled[col] * self.policy_count).sum()
            
            estimate = pd.Series(estimate_values)
            
        else:  # Original logic if no agg is specified (all sum)
            actual = df.sum()
            estimate = self.extract_and_scale_reps(df).sum()
        
        # Calculate error, handling division by zero
        error = np.where(actual != 0, estimate / actual - 1, 0)
        
        return pd.DataFrame({'actual': actual, 'estimate': estimate, 'error': error})
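

# Minimal usage sketch for Clusters (synthetic data, for illustration only; nothing in the app
# calls this). Column names and sizes here are made up; the only structural assumptions are the
# ones the class itself relies on: one row per policy and a 1-based index named 'policy_id'.
# Note that Clusters hard-codes n_clusters=1000, so the input needs at least 1,000 rows.
def _clusters_usage_example():
    rng = np.random.default_rng(0)
    toy = pd.DataFrame(rng.normal(size=(2000, 3)), columns=['cf_year_1', 'cf_year_2', 'cf_year_3'])
    toy.index = pd.RangeIndex(1, len(toy) + 1, name='policy_id')
    clusters = Clusters(toy)               # fit k-means and pick one representative policy per cluster
    totals = clusters.compare_total(toy)   # per-column totals: actual vs estimate vs relative error
    per_cluster = clusters.compare(toy)    # per-cluster, per-column actual vs estimate
    return totals, per_cluster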


def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
    """Create cashflow comparison plots"""
    if not cfs_list or not cluster_obj or not titles:
        return None
    num_plots = len(cfs_list)
    if num_plots == 0:
        return None

    # Determine subplot layout
    cols = 2
    rows = (num_plots + cols - 1) // cols
    
    fig, axes = plt.subplots(rows, cols, figsize=(15, 5 * rows), squeeze=False)
    axes = axes.flatten()
    
    for i, (df, title) in enumerate(zip(cfs_list, titles)):
        if i < len(axes):
            comparison = cluster_obj.compare_total(df)
            comparison[['actual', 'estimate']].plot(ax=axes[i], grid=True, title=title)
            axes[i].set_xlabel('Time')
            axes[i].set_ylabel('Value')
    
    # Hide any unused subplots
    for j in range(i + 1, len(axes)):
        fig.delaxes(axes[j])
        
    plt.tight_layout()
    buf = io.BytesIO()
    plt.savefig(buf, format='png', dpi=100)
    buf.seek(0)
    img = Image.open(buf)
    plt.close(fig)
    return img

def plot_scatter_comparison(df_compare_output, title):
    """Create scatter plot comparison from compare() output"""
    if df_compare_output is None or df_compare_output.empty:
        # Create a blank plot with a message
        fig, ax = plt.subplots(figsize=(12, 8))
        ax.text(0.5, 0.5, "No data to display", ha='center', va='center', fontsize=15)
        ax.set_title(title)
        buf = io.BytesIO()
        plt.savefig(buf, format='png', dpi=100)
        buf.seek(0)
        img = Image.open(buf)
        plt.close(fig)
        return img

    fig, ax = plt.subplots(figsize=(12, 8))
    
    if not isinstance(df_compare_output.index, pd.MultiIndex) or df_compare_output.index.nlevels < 2:
         gr.Warning("Scatter plot data is not in the expected multi-index format. Plotting raw actual vs estimate.")
         ax.scatter(df_compare_output['actual'], df_compare_output['estimate'], s=9, alpha=0.6)
    else:
        unique_levels = df_compare_output.index.get_level_values(1).unique()
        colors = matplotlib.cm.rainbow(np.linspace(0, 1, len(unique_levels)))
        
        for item_level, color_val in zip(unique_levels, colors):
            subset = df_compare_output.xs(item_level, level=1)
            ax.scatter(subset['actual'], subset['estimate'], color=color_val, s=9, alpha=0.6, label=item_level)
        if len(unique_levels) > 1 and len(unique_levels) <= 10:
            ax.legend(title=df_compare_output.index.names[1])

    ax.set_xlabel('Actual')
    ax.set_ylabel('Estimate')
    ax.set_title(title)
    ax.grid(True)
    
    # Draw identity line
    lims = [
        np.min([ax.get_xlim(), ax.get_ylim()]),
        np.max([ax.get_xlim(), ax.get_ylim()]),
    ]
    if lims[0] != lims[1]:
        ax.plot(lims, lims, 'r-', linewidth=0.5)
        ax.set_xlim(lims)
        ax.set_ylim(lims)
    
    buf = io.BytesIO()
    plt.savefig(buf, format='png', dpi=100)
    buf.seek(0)
    img = Image.open(buf)
    plt.close(fig)
    return img


def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
                  policy_data_path, pv_base_path, pv_lapse_path, pv_mort_path):
    """Main processing function - now accepts file paths"""
    try:
        # Read uploaded files using paths
        cfs = pd.read_excel(cashflow_base_path, index_col=0)
        cfs_lapse50 = pd.read_excel(cashflow_lapse_path, index_col=0)
        cfs_mort15 = pd.read_excel(cashflow_mort_path, index_col=0)
        
        pol_data_full = pd.read_excel(policy_data_path, index_col=0)
        # Ensure the correct columns are selected for pol_data
        required_cols = ['age_at_entry', 'policy_term', 'sum_assured', 'duration_mth']
        if all(col in pol_data_full.columns for col in required_cols):
            pol_data = pol_data_full[required_cols]
        else:
            gr.Warning(f"Policy data might be missing required columns. Found: {pol_data_full.columns.tolist()}")
            pol_data = pol_data_full

        pvs = pd.read_excel(pv_base_path, index_col=0)
        pvs_lapse50 = pd.read_excel(pv_lapse_path, index_col=0)
        pvs_mort15 = pd.read_excel(pv_mort_path, index_col=0)
        
        cfs_list = [cfs, cfs_lapse50, cfs_mort15]
        scen_titles = ['Base', 'Lapse+50%', 'Mort+15%']
        
        results = {}
        
        mean_attrs = {'age_at_entry':'mean', 'policy_term':'mean', 'duration_mth':'mean', 'sum_assured': 'sum'}
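        # mean_attrs is passed as the `agg` argument: compare_total / agg_by_cluster will average
        # age_at_entry, policy_term and duration_mth, but sum sum_assured.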

        # --- 1. Cashflow Calibration ---
        cluster_cfs = Clusters(cfs)
        
        results['cf_total_base_table'] = cluster_cfs.compare_total(cfs)
        results['cf_policy_attrs_total'] = cluster_cfs.compare_total(pol_data, agg=mean_attrs)
        
        results['cf_pv_total_base'] = cluster_cfs.compare_total(pvs)
        results['cf_pv_total_lapse'] = cluster_cfs.compare_total(pvs_lapse50)
        results['cf_pv_total_mort'] = cluster_cfs.compare_total(pvs_mort15)
        
        results['cf_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_cfs, scen_titles)
        results['cf_scatter_cashflows_base'] = plot_scatter_comparison(cluster_cfs.compare(cfs), 'Cashflow Calib. - Cashflows (Base)')

        # --- 2. Policy Attribute Calibration ---
        # Standardize policy attributes
        if not pol_data.empty and (pol_data.max() > pol_data.min()).all():
            # Min-max scale each attribute to [0, 1] so no single attribute dominates the k-means distances
            loc_vars_attrs = (pol_data - pol_data.min()) / (pol_data.max() - pol_data.min())
        else:
            gr.Warning("Policy data for attribute calibration is empty or has no variance. Skipping attribute calibration plots.")
            loc_vars_attrs = pol_data
        
        if not loc_vars_attrs.empty:
            cluster_attrs = Clusters(loc_vars_attrs)
            results['attr_total_cf_base'] = cluster_attrs.compare_total(cfs)
            results['attr_policy_attrs_total'] = cluster_attrs.compare_total(pol_data, agg=mean_attrs)
            results['attr_total_pv_base'] = cluster_attrs.compare_total(pvs)
            results['attr_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_attrs, scen_titles)
            results['attr_scatter_cashflows_base'] = plot_scatter_comparison(cluster_attrs.compare(cfs), 'Policy Attr. Calib. - Cashflows (Base)')
        else:
            results['attr_total_cf_base'] = pd.DataFrame()
            results['attr_policy_attrs_total'] = pd.DataFrame()
            results['attr_total_pv_base'] = pd.DataFrame()
            results['attr_cashflow_plot'] = None
            results['attr_scatter_cashflows_base'] = None

        # --- 3. Present Value Calibration ---
        cluster_pvs = Clusters(pvs)
        
        results['pv_total_cf_base'] = cluster_pvs.compare_total(cfs)
        results['pv_policy_attrs_total'] = cluster_pvs.compare_total(pol_data, agg=mean_attrs)
        
        results['pv_total_pv_base'] = cluster_pvs.compare_total(pvs)
        results['pv_total_pv_lapse'] = cluster_pvs.compare_total(pvs_lapse50)
        results['pv_total_pv_mort'] = cluster_pvs.compare_total(pvs_mort15)
        
        results['pv_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_pvs, scen_titles)
        results['pv_scatter_pvs_base'] = plot_scatter_comparison(cluster_pvs.compare(pvs), 'PV Calib. - PVs (Base)')

        # --- Summary Comparison Plot Data ---
        # For each calibration method and scenario, record the absolute relative error of the
        # total PV (using the key PV column if one is identified below, otherwise the mean
        # absolute error across all PV columns).
        
        error_data = {}
        
        # Function to safely get error value
        def get_error_safe(compare_result, col_name=None):
            if compare_result.empty:
                return np.nan
            if col_name and col_name in compare_result.index:
                return abs(compare_result.loc[col_name, 'error'])
            else:
                # Use mean absolute error if specific column not found
                return abs(compare_result['error']).mean()
        
        # Determine key PV column (try common names)
        key_pv_col = None
        for potential_col in ['PV_NetCF', 'pv_net_cf', 'net_cf_pv', 'PV_Net_CF']:
            if potential_col in pvs.columns:
                key_pv_col = potential_col
                break
        
        # Cashflow Calibration Errors
        error_data['CF Calib.'] = [
            get_error_safe(cluster_cfs.compare_total(pvs), key_pv_col),
            get_error_safe(cluster_cfs.compare_total(pvs_lapse50), key_pv_col),
            get_error_safe(cluster_cfs.compare_total(pvs_mort15), key_pv_col)
        ]

        # Policy Attribute Calibration Errors
        if not loc_vars_attrs.empty:
            error_data['Attr Calib.'] = [
                get_error_safe(cluster_attrs.compare_total(pvs), key_pv_col),
                get_error_safe(cluster_attrs.compare_total(pvs_lapse50), key_pv_col),
                get_error_safe(cluster_attrs.compare_total(pvs_mort15), key_pv_col)
            ]
        else:
            error_data['Attr Calib.'] = [np.nan, np.nan, np.nan]

        # Present Value Calibration Errors
        error_data['PV Calib.'] = [
            get_error_safe(cluster_pvs.compare_total(pvs), key_pv_col),
            get_error_safe(cluster_pvs.compare_total(pvs_lapse50), key_pv_col),
            get_error_safe(cluster_pvs.compare_total(pvs_mort15), key_pv_col)
        ]
        
        # Create Summary Plot
        summary_df = pd.DataFrame(error_data, index=['Base', 'Lapse+50%', 'Mort+15%'])
        
        fig_summary, ax_summary = plt.subplots(figsize=(10, 6))
        summary_df.plot(kind='bar', ax=ax_summary, grid=True)
        ax_summary.set_ylabel('Absolute Error Rate')
        title_suffix = f' ({key_pv_col})' if key_pv_col else ' (Mean Absolute Error)'
        ax_summary.set_title(f'Calibration Method Comparison - Error in Total PV{title_suffix}')
        ax_summary.tick_params(axis='x', rotation=0)
        ax_summary.legend(title='Calibration Method')
        plt.tight_layout()
        
        buf_summary = io.BytesIO()
        plt.savefig(buf_summary, format='png', dpi=100)
        buf_summary.seek(0)
        results['summary_plot'] = Image.open(buf_summary)
        plt.close(fig_summary)
        
        return results
        
    except FileNotFoundError as e:
        raise gr.Error(f"File not found: {e.filename}. Please ensure example files are in '{EXAMPLE_DATA_DIR}' or all files are uploaded.")
    except KeyError as e:
        raise gr.Error(f"A required column is missing from one of the Excel files: {e}. Please check the data format.")
    except Exception as e:
        raise gr.Error(f"Error processing files: {e}")


def create_interface():
    with gr.Blocks(title="Cluster Model Points Analysis") as demo:
        gr.Markdown("""
        # Cluster Model Points Analysis
        
        This application applies cluster analysis to model point selection for insurance portfolios.
        Upload your Excel files or use the example data to analyze cashflows, policy attributes, and present values using different calibration methods.
        
        **Required Files (Excel .xlsx):**
        - Cashflows - Base Scenario
        - Cashflows - Lapse Stress (+50%)
        - Cashflows - Mortality Stress (+15%)
        - Policy Data (including 'age_at_entry', 'policy_term', 'sum_assured', 'duration_mth')
        - Present Values - Base Scenario
        - Present Values - Lapse Stress
        - Present Values - Mortality Stress
        """)
        
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Upload Files or Load Examples")
                
                load_example_btn = gr.Button("Load Example Data")
                
                with gr.Row():
                    cashflow_base_input = gr.File(label="Cashflows - Base", file_types=[".xlsx"])
                    cashflow_lapse_input = gr.File(label="Cashflows - Lapse Stress", file_types=[".xlsx"])
                    cashflow_mort_input = gr.File(label="Cashflows - Mortality Stress", file_types=[".xlsx"])
                with gr.Row():
                    policy_data_input = gr.File(label="Policy Data", file_types=[".xlsx"])
                    pv_base_input = gr.File(label="Present Values - Base", file_types=[".xlsx"])
                    pv_lapse_input = gr.File(label="Present Values - Lapse Stress", file_types=[".xlsx"])
                with gr.Row():
                    pv_mort_input = gr.File(label="Present Values - Mortality Stress", file_types=[".xlsx"])
                
                analyze_btn = gr.Button("Analyze Dataset", variant="primary", size="lg")
        
        with gr.Tabs():
            with gr.TabItem("📊 Summary"):
                summary_plot_output = gr.Image(label="Calibration Methods Comparison")
            
            with gr.TabItem("💸 Cashflow Calibration"):
                gr.Markdown("### Results: Using Annual Cashflows as Calibration Variables")
                with gr.Row():
                    cf_total_base_table_out = gr.Dataframe(label="Overall Comparison - Base Scenario (Cashflows)")
                    cf_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes")
                cf_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate) Across Scenarios")
                cf_scatter_cashflows_base_out = gr.Image(label="Scatter Plot - Per-Cluster Cashflows (Base Scenario)")
                with gr.Accordion("Present Value Comparisons (Total)", open=False):
                    with gr.Row():
                        cf_pv_total_base_out = gr.Dataframe(label="PVs - Base Total")
                        cf_pv_total_lapse_out = gr.Dataframe(label="PVs - Lapse Stress Total")
                        cf_pv_total_mort_out = gr.Dataframe(label="PVs - Mortality Stress Total")
            
            with gr.TabItem("👤 Policy Attribute Calibration"):
                gr.Markdown("### Results: Using Policy Attributes as Calibration Variables")
                with gr.Row():
                    attr_total_cf_base_out = gr.Dataframe(label="Overall Comparison - Base Scenario (Cashflows)")
                    attr_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes")
                attr_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate) Across Scenarios")
                attr_scatter_cashflows_base_out = gr.Image(label="Scatter Plot - Per-Cluster Cashflows (Base Scenario)")
                with gr.Accordion("Present Value Comparisons (Total)", open=False):
                    attr_total_pv_base_out = gr.Dataframe(label="PVs - Base Scenario Total")

            with gr.TabItem("💰 Present Value Calibration"):
                gr.Markdown("### Results: Using Present Values (Base Scenario) as Calibration Variables")
                with gr.Row():
                    pv_total_cf_base_out = gr.Dataframe(label="Overall Comparison - Base Scenario (Cashflows)")
                    pv_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes")
                pv_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate) Across Scenarios")
                pv_scatter_pvs_base_out = gr.Image(label="Scatter Plot - Per-Cluster Present Values (Base Scenario)")
                with gr.Accordion("Present Value Comparisons (Total)", open=False):
                    with gr.Row():
                        pv_total_pv_base_out = gr.Dataframe(label="PVs - Base Total")
                        pv_total_pv_lapse_out = gr.Dataframe(label="PVs - Lapse Stress Total")
                        pv_total_pv_mort_out = gr.Dataframe(label="PVs - Mortality Stress Total")

        # --- Helper function to prepare outputs ---
        def get_all_output_components():
            return [
                summary_plot_output,
                # Cashflow Calib Outputs
                cf_total_base_table_out, cf_policy_attrs_total_out,
                cf_cashflow_plot_out, cf_scatter_cashflows_base_out,
                cf_pv_total_base_out, cf_pv_total_lapse_out, cf_pv_total_mort_out,
                # Attribute Calib Outputs
                attr_total_cf_base_out, attr_policy_attrs_total_out,
                attr_cashflow_plot_out, attr_scatter_cashflows_base_out, attr_total_pv_base_out,
                # PV Calib Outputs
                pv_total_cf_base_out, pv_policy_attrs_total_out,
                pv_cashflow_plot_out, pv_scatter_pvs_base_out,
                pv_total_pv_base_out, pv_total_pv_lapse_out, pv_total_pv_mort_out
            ]
        
        # --- Action for Analyze Button ---
        def handle_analysis(f1, f2, f3, f4, f5, f6, f7):
            files = [f1, f2, f3, f4, f5, f6, f7]
            
            file_paths = []
            for i, f_obj in enumerate(files):
                if f_obj is None:
                    raise gr.Error(f"Missing file input for argument {i+1}. Please upload all files or load examples.")
                
                # If f_obj is a Gradio FileData object (from direct upload)
                if hasattr(f_obj, 'name') and isinstance(f_obj.name, str):
                    file_paths.append(f_obj.name)
                # If f_obj is already a string path (from example load)
                elif isinstance(f_obj, str):
                    file_paths.append(f_obj)
                else:
                    raise gr.Error(f"Invalid file input for argument {i+1}. Type: {type(f_obj)}")

            results = process_files(*file_paths)
            
            if "error" in results:
                return [None] * len(get_all_output_components())
            
            return [
                results.get('summary_plot'),
                # CF Calib
                results.get('cf_total_base_table'), results.get('cf_policy_attrs_total'),
                results.get('cf_cashflow_plot'), results.get('cf_scatter_cashflows_base'),
                results.get('cf_pv_total_base'), results.get('cf_pv_total_lapse'), results.get('cf_pv_total_mort'),
                # Attr Calib
                results.get('attr_total_cf_base'), results.get('attr_policy_attrs_total'),
                results.get('attr_cashflow_plot'), results.get('attr_scatter_cashflows_base'), results.get('attr_total_pv_base'),
                # PV Calib
                results.get('pv_total_cf_base'), results.get('pv_policy_attrs_total'),
                results.get('pv_cashflow_plot'), results.get('pv_scatter_pvs_base'),
                results.get('pv_total_pv_base'), results.get('pv_total_pv_lapse'), results.get('pv_total_pv_mort')
            ]

        analyze_btn.click(
            handle_analysis,
            inputs=[cashflow_base_input, cashflow_lapse_input, cashflow_mort_input,
                    policy_data_input, pv_base_input, pv_lapse_input, pv_mort_input],
            outputs=get_all_output_components()
        )

        # --- Action for Load Example Data Button ---
        def load_example_files():
            missing_files = [fp for fp in EXAMPLE_FILES.values() if not os.path.exists(fp)]
            if missing_files:
                raise gr.Error(f"Missing example data files in '{EXAMPLE_DATA_DIR}': {', '.join(missing_files)}. Please ensure they exist.")
            
            gr.Info("Example data paths loaded. Click 'Analyze Dataset'.")
            return [
                EXAMPLE_FILES["cashflow_base"], EXAMPLE_FILES["cashflow_lapse"], EXAMPLE_FILES["cashflow_mort"],
                EXAMPLE_FILES["policy_data"], EXAMPLE_FILES["pv_base"], EXAMPLE_FILES["pv_lapse"],
                EXAMPLE_FILES["pv_mort"]
            ]

        load_example_btn.click(
            load_example_files,
            inputs=[],
            outputs=[cashflow_base_input, cashflow_lapse_input, cashflow_mort_input,
                     policy_data_input, pv_base_input, pv_lapse_input, pv_mort_input]
        )
            
    return demo

if __name__ == "__main__":
    if not os.path.exists(EXAMPLE_DATA_DIR):
        os.makedirs(EXAMPLE_DATA_DIR)
        print(f"Created directory '{EXAMPLE_DATA_DIR}'. Please place example Excel files there.")
        print(f"Expected files in '{EXAMPLE_DATA_DIR}': {list(EXAMPLE_FILES.values())}")

    demo_app = create_interface()
    demo_app.launch()