import base64
import io
import logging
import os

import gradio as gr
import matplotlib.cm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image
from sklearn.cluster import KMeans
from sklearn.metrics import pairwise_distances_argmin_min, r2_score

logger = logging.getLogger(__name__)


class Clusters:
    """Select representative model points by k-means clustering.

    The rows of ``loc_vars`` (one row per policy, in the same row order as the
    frames later passed to the comparison methods) are clustered; for each
    cluster the policy nearest the cluster centre becomes its representative.

    Attributes:
        kmeans: The fitted ``KMeans`` estimator.
        rep_ids: Series ``cluster_id -> policy_id`` (1-based) of representatives.
        policy_count: Series ``cluster_id -> number of policies in the cluster``.
    """

    def __init__(self, loc_vars, n_clusters=1000, random_state=0):
        """Fit k-means on ``loc_vars`` and pick one representative per cluster.

        Args:
            loc_vars: Array-like of calibration variables, one row per policy.
            n_clusters: Requested number of clusters (model points). Capped at
                the number of rows, because KMeans cannot form more clusters
                than there are samples.
            random_state: Seed passed to KMeans for reproducible clustering.
        """
        data = np.ascontiguousarray(loc_vars)
        n_clusters = min(n_clusters, len(data))
        self.kmeans = kmeans = KMeans(
            n_clusters=n_clusters, random_state=random_state, n_init=10
        ).fit(data)
        # Row position (0-based) of the policy nearest each cluster centre.
        closest, _ = pairwise_distances_argmin_min(kmeans.cluster_centers_, data)
        # NOTE(review): the +1 assumes policy ids are 1..n in row order — confirm
        # against the input files.
        rep_ids = pd.Series(data=(closest + 1))  # 0-based to 1-based indexes
        rep_ids.name = 'policy_id'
        rep_ids.index.name = 'cluster_id'
        self.rep_ids = rep_ids
        self.policy_count = self.agg_by_cluster(
            pd.DataFrame({'policy_count': [1] * len(data)})
        )['policy_count']

    def agg_by_cluster(self, df, agg=None):
        """Aggregate the columns of ``df`` within each cluster.

        Args:
            df: Frame with one row per policy, in the clustering row order.
            agg: Optional mapping ``column -> aggregation`` (e.g. ``'mean'``);
                columns not listed are summed.

        Returns:
            Frame indexed by ``cluster_id``, one aggregated row per cluster.
        """
        temp = df.copy()
        temp['cluster_id'] = self.kmeans.labels_
        temp = temp.set_index('cluster_id')
        agg = {c: agg.get(c, 'sum') for c in temp.columns} if agg else 'sum'
        return temp.groupby(temp.index).agg(agg)

    def extract_reps(self, df):
        """Return the rows of ``df`` that belong to the representative policies.

        The lookup is a merge on ``policy_id``, so ``df``'s index is expected to
        be named ``policy_id``. The result is indexed by ``cluster_id`` and the
        ``policy_id`` column is dropped.
        """
        temp = pd.merge(self.rep_ids, df.reset_index(), how='left', on='policy_id')
        temp.index.name = 'cluster_id'
        return temp.drop('policy_id', axis=1)

    def extract_and_scale_reps(self, df, agg=None):
        """Extract representative rows and scale them up to cluster level.

        Columns aggregated by ``'sum'`` (the default for columns not named in
        ``agg``) are multiplied by the cluster's policy count; all other
        columns (e.g. ``'mean'``) are left unscaled.
        """
        reps = self.extract_reps(df)
        if not agg:
            return reps.mul(self.policy_count, axis=0)
        # An explicit index lets the frame be built even when every entry is
        # the scalar 1 (no summed columns).
        mult = pd.DataFrame(
            {c: (self.policy_count if agg.get(c, 'sum') == 'sum' else 1) for c in df.columns},
            index=self.policy_count.index,
        )
        return reps.mul(mult)

    def compare(self, df, agg=None):
        """Return a per-cluster comparison of actual and estimated values.

        The result has a ``(cluster_id, column)`` MultiIndex and the columns
        ``'actual'`` (aggregate over all policies in the cluster) and
        ``'estimate'`` (scaled representative values).
        """
        source = self.agg_by_cluster(df, agg)
        target = self.extract_and_scale_reps(df, agg)
        return pd.DataFrame({'actual': source.stack(), 'estimate': target.stack()})

    def compare_total(self, df, agg=None):
        """Compare portfolio-level totals of ``df`` with the model-point estimate.

        Args:
            df: Frame with one row per policy.
            agg: Optional mapping ``column -> aggregation``; columns not listed
                are summed. For ``'mean'`` columns the estimate is the
                policy-count-weighted mean of the representatives.

        Returns:
            Frame indexed by ``df``'s columns with ``actual``, ``estimate`` and
            relative ``error`` (``estimate / actual - 1``).
        """
        if agg:
            op = {c: agg.get(c, 'sum') for c in df.columns}
            actual = df.agg(op)
            reps = self.extract_and_scale_reps(df, agg=op)
            weights = self.policy_count
            # Mean columns are unscaled by extract_and_scale_reps, so weight
            # each representative by its cluster size.
            estimate = pd.Series({
                c: (reps[c].dot(weights) / weights.sum() if how == 'mean' else reps[c].agg(how))
                for c, how in op.items()
            })
        else:
            actual = df.sum()
            estimate = self.extract_and_scale_reps(df).sum()
        return pd.DataFrame({'actual': actual, 'estimate': estimate, 'error': estimate / actual - 1})


def _figure_to_image(fig=None, dpi=150):
    """Render ``fig`` (default: the current figure) to a PIL image and close it."""
    fig = plt.gcf() if fig is None else fig
    buf = io.BytesIO()
    fig.savefig(buf, format='png', dpi=dpi, bbox_inches='tight')
    buf.seek(0)
    plt.close(fig)
    return Image.open(buf)


def create_plot(plot_func, *args, **kwargs):
    """Open a 10x6 figure, call ``plot_func`` and return the current figure as an image."""
    plt.figure(figsize=(10, 6))
    plot_func(*args, **kwargs)
    return _figure_to_image()


def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
    """Plot actual vs estimated total cashflows, one panel per scenario.

    Up to four scenarios are drawn on a 2x2 grid; panels left without a
    scenario are hidden.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    axes = axes.flatten()
    for ax, df, title in zip(axes, cfs_list, titles):
        comparison = cluster_obj.compare_total(df)
        comparison[['actual', 'estimate']].plot(ax=ax, grid=True, title=title)
    for ax in axes[min(len(cfs_list), len(titles)):]:
        ax.set_visible(False)
    fig.tight_layout()
    return _figure_to_image(fig)


def plot_scatter_comparison(df, title):
    """Scatter estimate vs actual per cluster, one colour per compared column.

    ``df`` is a ``Clusters.compare`` result (``(cluster_id, column)`` index,
    ``actual``/``estimate`` columns). A red identity line marks perfect fits.
    """
    fig = plt.figure(figsize=(12, 8))
    # Drop labels no longer present in the data; xs() would fail on them.
    keys = df.index.remove_unused_levels().levels[1]
    colors = matplotlib.cm.rainbow(np.linspace(0, 1, len(keys)))
    for key, color in zip(keys, colors):
        sub = df.xs(key, level=1)
        plt.scatter(sub['actual'], sub['estimate'], color=color, s=9, alpha=0.6)
    plt.xlabel('Actual')
    plt.ylabel('Estimate')
    plt.title(title)
    plt.grid(True)
    # Draw identity line over a square range covering both axes.
    lims = [
        np.min([plt.xlim(), plt.ylim()]),
        np.max([plt.xlim(), plt.ylim()]),
    ]
    plt.plot(lims, lims, 'r-', linewidth=0.5)
    plt.xlim(lims)
    plt.ylim(lims)
    return _figure_to_image(fig)


def _file_path(file_obj):
    """Return the filesystem path of a Gradio upload.

    Depending on the Gradio version, ``gr.File`` yields a plain path or a
    tempfile-like object that exposes the path as ``.name``.
    """
    if isinstance(file_obj, (str, os.PathLike)):
        return file_obj
    return file_obj.name


def _read_excel(file_obj):
    """Read an uploaded Excel file, using its first column as the index."""
    return pd.read_excel(_file_path(file_obj), index_col=0)


def _mean_abs_error(comparison):
    """Mean of |relative error| in a ``compare_total`` result.

    Non-finite errors (actual total of zero) are ignored.
    """
    err = comparison['error'].replace([np.inf, -np.inf], np.nan)
    return err.abs().mean()


def _plot_summary(comparison_data):
    """Grouped bar chart of each method's error on cashflows and present values.

    Args:
        comparison_data: Ordered mapping ``method label -> [cashflow error,
            present value error]``; up to three methods are drawn.
    """
    fig, ax = plt.subplots(figsize=(12, 8))
    x = np.arange(2)
    width = 0.25
    for offset, (label, values) in zip((-width, 0.0, width), comparison_data.items()):
        ax.bar(x + offset, values, width, label=label)
    ax.set_ylabel('Mean Absolute Error')
    ax.set_title('Calibration Method Comparison')
    ax.set_xticks(x)
    ax.set_xticklabels(['Cashflows', 'Present Values'])
    ax.legend()
    ax.grid(True, alpha=0.3)
    return _figure_to_image(fig)


def process_files(cashflow_base, cashflow_lapse, cashflow_mort, policy_data, pv_base, pv_lapse, pv_mort):
    """Run cashflow, policy-attribute and present-value calibrations.

    Each argument is an uploaded Excel file whose first column is the index.

    Returns:
        Dict of result DataFrames and PIL images keyed by the names used in
        ``create_interface``, or ``{"error": message}`` if processing fails.
    """
    try:
        # Read uploaded files
        cfs = _read_excel(cashflow_base)
        cfs_lapse50 = _read_excel(cashflow_lapse)
        cfs_mort15 = _read_excel(cashflow_mort)
        pol_data = _read_excel(policy_data)
        if pol_data.shape[1] > 4:
            pol_data = pol_data[['age_at_entry', 'policy_term', 'sum_assured', 'duration_mth']]
        pvs = _read_excel(pv_base)
        pvs_lapse50 = _read_excel(pv_lapse)
        pvs_mort15 = _read_excel(pv_mort)

        cfs_list = [cfs, cfs_lapse50, cfs_mort15]
        scen_titles = ['Base', 'Lapse+50%', 'Mort+15%']
        mean_attrs = {'age_at_entry': 'mean', 'policy_term': 'mean', 'duration_mth': 'mean'}
        results = {}

        # 1. Cashflow Calibration: cluster on base-scenario cashflows.
        cluster_cfs = Clusters(cfs)
        results['cf_base_table'] = cluster_cfs.compare_total(cfs)
        results['cf_lapse_table'] = cluster_cfs.compare_total(cfs_lapse50)
        results['cf_mort_table'] = cluster_cfs.compare_total(cfs_mort15)
        results['cf_policy_attrs'] = cluster_cfs.compare_total(pol_data, agg=mean_attrs)
        results['cf_pv_base'] = cluster_cfs.compare_total(pvs)
        results['cf_pv_lapse'] = cluster_cfs.compare_total(pvs_lapse50)
        results['cf_pv_mort'] = cluster_cfs.compare_total(pvs_mort15)
        results['cf_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_cfs, scen_titles)
        results['cf_scatter_base'] = plot_scatter_comparison(
            cluster_cfs.compare(cfs), 'Cashflow Calibration - Base Scenario')

        # 2. Policy Attribute Calibration: cluster on min-max scaled attributes.
        # A constant column has zero range; divide by 1 instead to avoid NaN.
        attr_range = (pol_data.max() - pol_data.min()).replace(0, 1)
        loc_vars = (pol_data - pol_data.min()) / attr_range
        cluster_attrs = Clusters(loc_vars)
        results['attr_cf_base'] = cluster_attrs.compare_total(cfs)
        results['attr_policy_attrs'] = cluster_attrs.compare_total(pol_data, agg=mean_attrs)
        results['attr_pv_base'] = cluster_attrs.compare_total(pvs)
        results['attr_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_attrs, scen_titles)
        results['attr_scatter_base'] = plot_scatter_comparison(
            cluster_attrs.compare(cfs), 'Policy Attribute Calibration - Base Scenario')

        # 3. Present Value Calibration: cluster on base-scenario present values.
        cluster_pvs = Clusters(pvs)
        results['pv_cf_base'] = cluster_pvs.compare_total(cfs)
        results['pv_policy_attrs'] = cluster_pvs.compare_total(pol_data, agg=mean_attrs)
        results['pv_pv_base'] = cluster_pvs.compare_total(pvs)
        results['pv_pv_lapse'] = cluster_pvs.compare_total(pvs_lapse50)
        results['pv_pv_mort'] = cluster_pvs.compare_total(pvs_mort15)
        results['pv_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_pvs, scen_titles)
        results['pv_scatter_base'] = plot_scatter_comparison(
            cluster_pvs.compare(pvs), 'Present Value Calibration - Base Scenario')

        # Summary: reuse the base-scenario comparisons computed above.
        results['summary_plot'] = _plot_summary({
            'Cashflow Calibration': [
                _mean_abs_error(results['cf_base_table']),
                _mean_abs_error(results['cf_pv_base']),
            ],
            'Policy Attribute Calibration': [
                _mean_abs_error(results['attr_cf_base']),
                _mean_abs_error(results['attr_pv_base']),
            ],
            'Present Value Calibration': [
                _mean_abs_error(results['pv_cf_base']),
                _mean_abs_error(results['pv_pv_base']),
            ],
        })
        return results
    except Exception as e:  # UI boundary: report the failure instead of crashing
        logger.exception("Error processing files")
        return {"error": f"Error processing files: {str(e)}"}


def _for_display(value):
    """Prepare a result value for a Gradio output component.

    ``gr.Dataframe`` renders a frame's columns but not its index, so the index
    (the compared item labels) is moved into a regular column.
    """
    if isinstance(value, pd.DataFrame):
        return value.reset_index()
    return value


_INTRO_MD = """
# Cluster Model Points Analysis

This application applies cluster analysis to model point selection for insurance portfolios.
Upload your Excel files to analyze cashflows, policy attributes, and present values using different calibration methods.

**Required Files:**
- 3 Cashflow files (Base, Lapse stress, Mortality stress scenarios)
- 1 Policy data file
- 3 Present value files (Base, Lapse stress, Mortality stress scenarios)
"""


def create_interface():
    """Build the Gradio UI and wire the Analyze button to ``process_files``."""
    with gr.Blocks(title="Cluster Model Points Analysis", theme=gr.themes.Soft()) as demo:
        gr.Markdown(_INTRO_MD)

        with gr.Row():
            with gr.Column():
                gr.Markdown("### Upload Files")
                cashflow_base = gr.File(label="Cashflows - Base Scenario", file_types=[".xlsx"])
                cashflow_lapse = gr.File(label="Cashflows - Lapse Stress (+50%)", file_types=[".xlsx"])
                cashflow_mort = gr.File(label="Cashflows - Mortality Stress (+15%)", file_types=[".xlsx"])
                policy_data = gr.File(label="Policy Data", file_types=[".xlsx"])
                pv_base = gr.File(label="Present Values - Base Scenario", file_types=[".xlsx"])
                pv_lapse = gr.File(label="Present Values - Lapse Stress", file_types=[".xlsx"])
                pv_mort = gr.File(label="Present Values - Mortality Stress", file_types=[".xlsx"])
                analyze_btn = gr.Button("Analyze", variant="primary", size="lg")

        with gr.Tabs():
            with gr.TabItem("Summary"):
                summary_plot = gr.Image(label="Calibration Methods Comparison")

            with gr.TabItem("Cashflow Calibration"):
                gr.Markdown("### Results using Annual Cashflows as Calibration Variables")
                with gr.Row():
                    cf_base_table = gr.Dataframe(label="Base Scenario Comparison")
                    cf_policy_attrs = gr.Dataframe(label="Policy Attributes Comparison")
                cf_cashflow_plot = gr.Image(label="Cashflow Comparisons Across Scenarios")
                cf_scatter_base = gr.Image(label="Scatter Plot - Base Scenario")
                with gr.Row():
                    cf_pv_base = gr.Dataframe(label="Present Values - Base")
                    cf_pv_lapse = gr.Dataframe(label="Present Values - Lapse Stress")
                    cf_pv_mort = gr.Dataframe(label="Present Values - Mortality Stress")

            with gr.TabItem("Policy Attribute Calibration"):
                gr.Markdown("### Results using Policy Attributes as Calibration Variables")
                with gr.Row():
                    attr_cf_base = gr.Dataframe(label="Cashflows - Base Scenario")
                    attr_policy_attrs = gr.Dataframe(label="Policy Attributes Comparison")
                attr_cashflow_plot = gr.Image(label="Cashflow Comparisons Across Scenarios")
                attr_scatter_base = gr.Image(label="Scatter Plot - Base Scenario")
                attr_pv_base = gr.Dataframe(label="Present Values - Base Scenario")

            with gr.TabItem("Present Value Calibration"):
                gr.Markdown("### Results using Present Values as Calibration Variables")
                with gr.Row():
                    pv_cf_base = gr.Dataframe(label="Cashflows - Base Scenario")
                    pv_policy_attrs = gr.Dataframe(label="Policy Attributes Comparison")
                pv_cashflow_plot = gr.Image(label="Cashflow Comparisons Across Scenarios")
                pv_scatter_base = gr.Image(label="Scatter Plot - Base Scenario")
                with gr.Row():
                    pv_pv_base = gr.Dataframe(label="Present Values - Base")
                    pv_pv_lapse = gr.Dataframe(label="Present Values - Lapse Stress")
                    pv_pv_mort = gr.Dataframe(label="Present Values - Mortality Stress")

        # Single source of truth for result key -> output component, so the
        # returned list always has exactly one value per declared output.
        outputs_by_key = {
            'summary_plot': summary_plot,
            'cf_base_table': cf_base_table,
            'cf_policy_attrs': cf_policy_attrs,
            'cf_cashflow_plot': cf_cashflow_plot,
            'cf_scatter_base': cf_scatter_base,
            'cf_pv_base': cf_pv_base,
            'cf_pv_lapse': cf_pv_lapse,
            'cf_pv_mort': cf_pv_mort,
            'attr_cf_base': attr_cf_base,
            'attr_policy_attrs': attr_policy_attrs,
            'attr_cashflow_plot': attr_cashflow_plot,
            'attr_scatter_base': attr_scatter_base,
            'attr_pv_base': attr_pv_base,
            'pv_cf_base': pv_cf_base,
            'pv_policy_attrs': pv_policy_attrs,
            'pv_cashflow_plot': pv_cashflow_plot,
            'pv_scatter_base': pv_scatter_base,
            'pv_pv_base': pv_pv_base,
            'pv_pv_lapse': pv_pv_lapse,
            'pv_pv_mort': pv_pv_mort,
        }
        output_keys = list(outputs_by_key)

        def update_interface(cashflow_base, cashflow_lapse, cashflow_mort, policy_data, pv_base, pv_lapse, pv_mort):
            """Click handler: return one value per output, in ``output_keys`` order."""
            files = [cashflow_base, cashflow_lapse, cashflow_mort, policy_data, pv_base, pv_lapse, pv_mort]
            empty = [None] * len(output_keys)
            if not all(files):
                gr.Warning("Please upload all seven files before analyzing.")
                return empty
            results = process_files(*files)
            if "error" in results:
                gr.Warning(results["error"])
                return empty
            return [_for_display(results.get(key)) for key in output_keys]

        analyze_btn.click(
            update_interface,
            inputs=[cashflow_base, cashflow_lapse, cashflow_mort, policy_data, pv_base, pv_lapse, pv_mort],
            outputs=list(outputs_by_key.values()),
        )

    return demo


if __name__ == "__main__":
    demo = create_interface()
    demo.launch()