Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import numpy as np | |
| import pandas as pd | |
| from sklearn.cluster import KMeans | |
| from sklearn.metrics import pairwise_distances_argmin_min, r2_score | |
| import matplotlib.pyplot as plt | |
| import matplotlib.cm | |
| import io | |
| import base64 | |
| from PIL import Image | |
class Clusters:
    """K-means clustering of policies with one representative row per cluster.

    After construction the instance exposes:

    * ``kmeans``: the fitted ``KMeans`` estimator (``labels_`` gives the
      cluster of each input row).
    * ``rep_ids``: Series named ``policy_id`` indexed by ``cluster_id``; the
      value is the 1-based row position of the input row nearest to each
      cluster centre.
    * ``policy_count``: number of input rows assigned to each cluster.

    NOTE(review): ``rep_ids`` are row positions + 1, so every frame passed to
    the other methods is assumed to have an index named ``policy_id`` holding
    1..N in the same row order as ``loc_vars`` -- confirm against the data.
    """

    def __init__(self, loc_vars, n_clusters=1000):
        """Fit k-means on ``loc_vars`` and select the representative rows.

        Args:
            loc_vars: Array-like with one row per policy.
            n_clusters: Requested number of clusters (default 1000, the value
                that used to be hard-coded). Capped at the number of rows,
                because k-means cannot form more clusters than samples.
        """
        # Convert once and reuse for both the fit and the nearest-row search.
        points = np.ascontiguousarray(loc_vars)
        n_clusters = min(n_clusters, len(points))
        self.kmeans = kmeans = KMeans(
            n_clusters=n_clusters, random_state=0, n_init=10
        ).fit(points)
        # For every cluster centre, the position of the closest input row.
        closest, _ = pairwise_distances_argmin_min(kmeans.cluster_centers_, points)
        rep_ids = pd.Series(data=(closest + 1))  # 0-based to 1-based indexes
        rep_ids.name = 'policy_id'
        rep_ids.index.name = 'cluster_id'
        self.rep_ids = rep_ids
        # Summing a column of ones per cluster yields the cluster sizes.
        ones = pd.DataFrame({'policy_count': [1] * len(loc_vars)})
        self.policy_count = self.agg_by_cluster(ones)['policy_count']

    def agg_by_cluster(self, df, agg=None):
        """Aggregate the columns of ``df`` by cluster.

        Args:
            df: DataFrame with one row per clustered policy, in fit order.
            agg: Optional mapping column -> aggregation; columns not listed
                are summed. When omitted (or empty) every column is summed.

        Returns:
            DataFrame indexed by ``cluster_id``.
        """
        temp = df.copy()
        temp['cluster_id'] = self.kmeans.labels_
        temp = temp.set_index('cluster_id')
        if agg:
            how = {c: (agg[c] if c in agg else 'sum') for c in temp.columns}
        else:
            how = 'sum'
        return temp.groupby(temp.index).agg(how)

    def extract_reps(self, df):
        """Extract the rows of ``df`` that belong to the representative policies.

        ``df`` must expose a ``policy_id`` column after ``reset_index()``.

        Returns:
            DataFrame indexed by ``cluster_id`` without the ``policy_id`` column.
        """
        temp = pd.merge(self.rep_ids, df.reset_index(), how='left', on='policy_id')
        temp.index.name = 'cluster_id'
        return temp.drop('policy_id', axis=1)

    def extract_and_scale_reps(self, df, agg=None):
        """Extract the representative rows and scale them by cluster size.

        Args:
            df: DataFrame with one row per policy.
            agg: Optional mapping column -> aggregation. Columns that are
                summed (or not listed) are multiplied by the cluster's policy
                count; all other columns are left unscaled.

        Returns:
            DataFrame indexed by ``cluster_id``.
        """
        reps = self.extract_reps(df)
        if not agg:
            return reps.mul(self.policy_count, axis=0)
        # An explicit index lets the frame be built even when every factor is
        # the scalar 1 (no summed column); without it pandas raises
        # "If using all scalar values, you must pass an index".
        mult = pd.DataFrame(
            {
                c: (self.policy_count if (c not in agg or agg[c] == 'sum') else 1)
                for c in df.columns
            },
            index=self.policy_count.index,
        )
        return reps.mul(mult)

    def compare(self, df, agg=None):
        """Return a multi-indexed DataFrame comparing actual and estimate.

        ``actual`` is the per-cluster aggregate of all policies and
        ``estimate`` the scaled representative row, both stacked so the index
        is (cluster_id, column).
        """
        source = self.agg_by_cluster(df, agg)
        target = self.extract_and_scale_reps(df, agg)
        return pd.DataFrame({'actual': source.stack(), 'estimate': target.stack()})

    def compare_total(self, df, agg=None):
        """Compare portfolio-level totals of all policies with the estimate.

        Args:
            df: DataFrame with one row per policy.
            agg: Optional mapping column -> aggregation; unlisted columns are
                summed. For ``'mean'`` the estimate is the policy-count
                weighted mean of the representative rows.

        Returns:
            DataFrame indexed by the columns of ``df`` with columns
            ``actual``, ``estimate`` and ``error`` (= estimate / actual - 1).
        """
        if agg:
            op = {c: (agg[c] if c in agg else 'sum') for c in df.columns}
            actual = df.agg(op)
            reps = self.extract_and_scale_reps(df, agg=op)
            weights = self.policy_count
            # Computed column by column so the weighted mean does not depend
            # on how ``Series.agg`` dispatches a callable.
            estimate = pd.Series({
                c: (
                    reps[c].dot(weights) / weights.sum()
                    if how == 'mean'
                    else reps[c].agg(how)
                )
                for c, how in op.items()
            })
        else:
            actual = df.sum()
            estimate = self.extract_and_scale_reps(df).sum()
        return pd.DataFrame(
            {'actual': actual, 'estimate': estimate, 'error': estimate / actual - 1}
        )
def create_plot(plot_func, *args, **kwargs):
    """Run ``plot_func`` on a fresh figure and return the result as an image.

    Args:
        plot_func: Callable that draws with matplotlib; it is invoked as
            ``plot_func(*args, **kwargs)`` while a new 10x6 inch figure is
            the current pyplot figure.
        *args: Positional arguments forwarded to ``plot_func``.
        **kwargs: Keyword arguments forwarded to ``plot_func``.

    Returns:
        PIL image opened from the PNG rendering (150 dpi, tight bounding
        box) of whichever figure is current after ``plot_func`` returns.

    Raises:
        Whatever ``plot_func`` or ``savefig`` raises; the figures are closed
        before the exception propagates.
    """
    fig = plt.figure(figsize=(10, 6))
    buf = io.BytesIO()
    try:
        plot_func(*args, **kwargs)
        # Save the *current* figure: plot_func may have opened its own.
        plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
    finally:
        # Close the current figure and the one created here, so nothing
        # leaks when plot_func fails or switches figures. Closing an
        # already-closed figure is a no-op.
        plt.close()
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
    """Plot actual vs. estimated totals for several scenarios on a 2x2 grid.

    Args:
        cfs_list: Sequence of DataFrames, one per scenario.
        cluster_obj: Object providing ``compare_total(df)`` that returns a
            frame with ``actual`` and ``estimate`` columns.
        titles: Panel titles, paired positionally with ``cfs_list``.

    Returns:
        PIL image of the rendered figure (PNG, 150 dpi). At most four
        scenarios are drawn; panels without a scenario are hidden.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    try:
        panels = axes.flatten()
        drawn = 0
        # zip stops at the shortest input, so at most len(panels) scenarios
        # are plotted.
        for ax, df, title in zip(panels, cfs_list, titles):
            comparison = cluster_obj.compare_total(df)
            comparison[['actual', 'estimate']].plot(ax=ax, grid=True, title=title)
            drawn += 1
        # Hide the panels that received no scenario instead of rendering
        # empty axes.
        for ax in panels[drawn:]:
            ax.set_visible(False)
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')
    finally:
        # Always release the figure, even when compare_total/plot fails.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
def plot_scatter_comparison(df, title):
    """Scatter ``estimate`` against ``actual``, one colour per inner index level.

    Args:
        df: DataFrame with ``actual`` and ``estimate`` columns and a
            MultiIndex; the second index level selects the colour group.
        title: Plot title.

    Returns:
        PIL image of the rendered figure (PNG, 150 dpi) with a red identity
        line and equal x/y limits.
    """
    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        groups = df.index.levels[1]
        colors = matplotlib.cm.rainbow(np.linspace(0, 1, len(groups)))
        for y, c in zip(groups, colors):
            subset = df.xs(y, level=1)  # take the cross-section once per group
            ax.scatter(subset['actual'], subset['estimate'], color=c, s=9, alpha=0.6)
        ax.set_xlabel('Actual')
        ax.set_ylabel('Estimate')
        ax.set_title(title)
        ax.grid(True)
        # Draw identity line across the combined range of both axes
        lims = [
            np.min([ax.get_xlim(), ax.get_ylim()]),
            np.max([ax.get_xlim(), ax.get_ylim()]),
        ]
        ax.plot(lims, lims, 'r-', linewidth=0.5)
        ax.set_xlim(lims)
        ax.set_ylim(lims)
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')
    finally:
        # Always release the figure, e.g. when df has no MultiIndex.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
def _uploaded_path(upload):
    """Return the path of an upload given as a path or an object with ``.name``."""
    if isinstance(upload, str) or hasattr(upload, '__fspath__'):
        return upload
    return upload.name


def _mean_abs_error(table):
    """Return the mean of the absolute values in ``table['error']``."""
    return table['error'].abs().mean()


def process_files(cashflow_base, cashflow_lapse, cashflow_mort, policy_data, pv_base, pv_lapse, pv_mort):
    """Run the three cluster calibrations on the uploaded Excel files.

    Each argument is an uploaded Excel file, given either as a path or as an
    object whose ``.name`` is the path. Every file is read with its first
    column as the index.

    Clusters are fitted three times -- on the base cashflows, on the min-max
    scaled policy attributes, and on the base present values -- and each fit
    is compared against the full data.

    Returns:
        On success, a dict of result tables (DataFrames) and plots (PIL
        images) keyed by names such as ``cf_base_table`` or ``summary_plot``.
        On any exception, ``{"error": "Error processing files: ..."}``.
    """
    try:
        # Read uploaded files
        cfs = pd.read_excel(_uploaded_path(cashflow_base), index_col=0)
        cfs_lapse50 = pd.read_excel(_uploaded_path(cashflow_lapse), index_col=0)
        cfs_mort15 = pd.read_excel(_uploaded_path(cashflow_mort), index_col=0)
        pol_data = pd.read_excel(_uploaded_path(policy_data), index_col=0)
        # Keep only the four calibration attributes when extra columns exist.
        if pol_data.shape[1] > 4:
            pol_data = pol_data[['age_at_entry', 'policy_term', 'sum_assured', 'duration_mth']]
        pvs = pd.read_excel(_uploaded_path(pv_base), index_col=0)
        pvs_lapse50 = pd.read_excel(_uploaded_path(pv_lapse), index_col=0)
        pvs_mort15 = pd.read_excel(_uploaded_path(pv_mort), index_col=0)

        cfs_list = [cfs, cfs_lapse50, cfs_mort15]
        scen_titles = ['Base', 'Lapse+50%', 'Mort+15%']
        # Attributes compared as means; any other column is summed.
        mean_attrs = {'age_at_entry': 'mean', 'policy_term': 'mean', 'duration_mth': 'mean'}
        results = {}

        # 1. Cashflow Calibration
        cluster_cfs = Clusters(cfs)
        results['cf_base_table'] = cluster_cfs.compare_total(cfs)
        results['cf_lapse_table'] = cluster_cfs.compare_total(cfs_lapse50)
        results['cf_mort_table'] = cluster_cfs.compare_total(cfs_mort15)
        results['cf_policy_attrs'] = cluster_cfs.compare_total(pol_data, agg=mean_attrs)
        results['cf_pv_base'] = cluster_cfs.compare_total(pvs)
        results['cf_pv_lapse'] = cluster_cfs.compare_total(pvs_lapse50)
        results['cf_pv_mort'] = cluster_cfs.compare_total(pvs_mort15)
        results['cf_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_cfs, scen_titles)
        results['cf_scatter_base'] = plot_scatter_comparison(
            cluster_cfs.compare(cfs), 'Cashflow Calibration - Base Scenario')

        # 2. Policy Attribute Calibration
        # Min-max scale each attribute; a constant column has zero range, so
        # divide by 1 instead to map it to 0 rather than NaN.
        low = pol_data.min()
        span = (pol_data.max() - low).replace(0, 1)
        loc_vars = (pol_data - low) / span
        cluster_attrs = Clusters(loc_vars)
        results['attr_cf_base'] = cluster_attrs.compare_total(cfs)
        results['attr_policy_attrs'] = cluster_attrs.compare_total(pol_data, agg=mean_attrs)
        results['attr_pv_base'] = cluster_attrs.compare_total(pvs)
        results['attr_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_attrs, scen_titles)
        results['attr_scatter_base'] = plot_scatter_comparison(
            cluster_attrs.compare(cfs), 'Policy Attribute Calibration - Base Scenario')

        # 3. Present Value Calibration
        cluster_pvs = Clusters(pvs)
        results['pv_cf_base'] = cluster_pvs.compare_total(cfs)
        results['pv_policy_attrs'] = cluster_pvs.compare_total(pol_data, agg=mean_attrs)
        results['pv_pv_base'] = cluster_pvs.compare_total(pvs)
        results['pv_pv_lapse'] = cluster_pvs.compare_total(pvs_lapse50)
        results['pv_pv_mort'] = cluster_pvs.compare_total(pvs_mort15)
        results['pv_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_pvs, scen_titles)
        results['pv_scatter_base'] = plot_scatter_comparison(
            cluster_pvs.compare(pvs), 'Present Value Calibration - Base Scenario')

        # Summary comparison plot. The tables were computed above, so reuse
        # them. The bars show the mean of the absolute errors, matching the
        # axis label (taking abs() of the mean would let errors cancel).
        comparison_data = {
            'Cashflow Calibration': [
                _mean_abs_error(results['cf_base_table']),
                _mean_abs_error(results['cf_pv_base']),
            ],
            'Policy Attribute Calibration': [
                _mean_abs_error(results['attr_cf_base']),
                _mean_abs_error(results['attr_pv_base']),
            ],
            'Present Value Calibration': [
                _mean_abs_error(results['pv_cf_base']),
                _mean_abs_error(results['pv_pv_base']),
            ],
        }
        fig, ax = plt.subplots(figsize=(12, 8))
        try:
            x = np.arange(2)
            width = 0.25
            # Three bars per group, centred on the tick.
            for offset, label in zip((-width, 0, width), comparison_data):
                ax.bar(x + offset, comparison_data[label], width, label=label)
            ax.set_ylabel('Mean Absolute Error')
            ax.set_title('Calibration Method Comparison')
            ax.set_xticks(x)
            ax.set_xticklabels(['Cashflows', 'Present Values'])
            ax.legend()
            ax.grid(True, alpha=0.3)
            buf = io.BytesIO()
            fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')
        finally:
            plt.close(fig)
        buf.seek(0)
        results['summary_plot'] = Image.open(buf)
        return results
    except Exception as e:
        # Top-level boundary for the UI: report the failure instead of raising.
        return {"error": f"Error processing files: {str(e)}"}
def create_interface():
    """Build the Gradio Blocks app for the cluster model-point analysis.

    The page has seven Excel upload fields, an "Analyze" button and four
    result tabs. Clicking the button passes the uploads to ``process_files``
    and fills the output components from the returned dict.

    NOTE(review): the layout nesting below was reconstructed from a source
    whose indentation was lost -- confirm rows/columns against the live app.

    Returns:
        The ``gr.Blocks`` instance (not yet launched).
    """
    with gr.Blocks(title="Cluster Model Points Analysis", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
# Cluster Model Points Analysis
This application applies cluster analysis to model point selection for insurance portfolios.
Upload your Excel files to analyze cashflows, policy attributes, and present values using different calibration methods.
**Required Files:**
- 3 Cashflow files (Base, Lapse stress, Mortality stress scenarios)
- 1 Policy data file
- 3 Present value files (Base, Lapse stress, Mortality stress scenarios)
""")
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Upload Files")
                cashflow_base = gr.File(label="Cashflows - Base Scenario", file_types=[".xlsx"])
                cashflow_lapse = gr.File(label="Cashflows - Lapse Stress (+50%)", file_types=[".xlsx"])
                cashflow_mort = gr.File(label="Cashflows - Mortality Stress (+15%)", file_types=[".xlsx"])
                policy_data = gr.File(label="Policy Data", file_types=[".xlsx"])
                pv_base = gr.File(label="Present Values - Base Scenario", file_types=[".xlsx"])
                pv_lapse = gr.File(label="Present Values - Lapse Stress", file_types=[".xlsx"])
                pv_mort = gr.File(label="Present Values - Mortality Stress", file_types=[".xlsx"])
                analyze_btn = gr.Button("Analyze", variant="primary", size="lg")
        with gr.Tabs():
            with gr.TabItem("Summary"):
                summary_plot = gr.Image(label="Calibration Methods Comparison")
            with gr.TabItem("Cashflow Calibration"):
                gr.Markdown("### Results using Annual Cashflows as Calibration Variables")
                with gr.Row():
                    cf_base_table = gr.Dataframe(label="Base Scenario Comparison")
                    cf_policy_attrs = gr.Dataframe(label="Policy Attributes Comparison")
                cf_cashflow_plot = gr.Image(label="Cashflow Comparisons Across Scenarios")
                cf_scatter_base = gr.Image(label="Scatter Plot - Base Scenario")
                with gr.Row():
                    cf_pv_base = gr.Dataframe(label="Present Values - Base")
                    cf_pv_lapse = gr.Dataframe(label="Present Values - Lapse Stress")
                    cf_pv_mort = gr.Dataframe(label="Present Values - Mortality Stress")
            with gr.TabItem("Policy Attribute Calibration"):
                gr.Markdown("### Results using Policy Attributes as Calibration Variables")
                with gr.Row():
                    attr_cf_base = gr.Dataframe(label="Cashflows - Base Scenario")
                    attr_policy_attrs = gr.Dataframe(label="Policy Attributes Comparison")
                attr_cashflow_plot = gr.Image(label="Cashflow Comparisons Across Scenarios")
                attr_scatter_base = gr.Image(label="Scatter Plot - Base Scenario")
                attr_pv_base = gr.Dataframe(label="Present Values - Base Scenario")
            with gr.TabItem("Present Value Calibration"):
                gr.Markdown("### Results using Present Values as Calibration Variables")
                with gr.Row():
                    pv_cf_base = gr.Dataframe(label="Cashflows - Base Scenario")
                    pv_policy_attrs = gr.Dataframe(label="Policy Attributes Comparison")
                pv_cashflow_plot = gr.Image(label="Cashflow Comparisons Across Scenarios")
                pv_scatter_base = gr.Image(label="Scatter Plot - Base Scenario")
                with gr.Row():
                    pv_pv_base = gr.Dataframe(label="Present Values - Base")
                    pv_pv_lapse = gr.Dataframe(label="Present Values - Lapse Stress")
                    pv_pv_mort = gr.Dataframe(label="Present Values - Mortality Stress")

        # Single source of truth pairing each ``process_files`` result key
        # with the component that displays it, so the number and order of
        # returned values can never drift from the wired outputs.
        output_slots = [
            ('summary_plot', summary_plot),
            ('cf_base_table', cf_base_table),
            ('cf_policy_attrs', cf_policy_attrs),
            ('cf_cashflow_plot', cf_cashflow_plot),
            ('cf_scatter_base', cf_scatter_base),
            ('cf_pv_base', cf_pv_base),
            ('cf_pv_lapse', cf_pv_lapse),
            ('cf_pv_mort', cf_pv_mort),
            ('attr_cf_base', attr_cf_base),
            ('attr_policy_attrs', attr_policy_attrs),
            ('attr_cashflow_plot', attr_cashflow_plot),
            ('attr_scatter_base', attr_scatter_base),
            ('attr_pv_base', attr_pv_base),
            ('pv_cf_base', pv_cf_base),
            ('pv_policy_attrs', pv_policy_attrs),
            ('pv_cashflow_plot', pv_cashflow_plot),
            ('pv_scatter_base', pv_scatter_base),
            ('pv_pv_base', pv_pv_base),
            ('pv_pv_lapse', pv_pv_lapse),
            ('pv_pv_mort', pv_pv_mort),
        ]
        result_keys = [key for key, _ in output_slots]
        output_components = [component for _, component in output_slots]

        def update_interface(cashflow_base, cashflow_lapse, cashflow_mort, policy_data, pv_base, pv_lapse, pv_mort):
            """Run the analysis and return one value per output component."""
            # One None per output clears every component. This was a
            # hard-coded 17 although 20 components are wired.
            blank = [None] * len(output_components)
            uploads = [cashflow_base, cashflow_lapse, cashflow_mort, policy_data, pv_base, pv_lapse, pv_mort]
            if not all(uploads):
                return blank
            results = process_files(*uploads)
            if "error" in results:
                gr.Warning(results["error"])
                return blank
            return [results.get(key) for key in result_keys]

        analyze_btn.click(
            update_interface,
            inputs=[cashflow_base, cashflow_lapse, cashflow_mort, policy_data, pv_base, pv_lapse, pv_mort],
            outputs=output_components,
        )
    return demo
if __name__ == "__main__":
    # Script entry point: build the Blocks app, keep it bound to the
    # module-level name ``demo``, then start the Gradio server.
    demo = create_interface()
    demo.launch()