import gradio as gr
import numpy as np
import pandas as pd
from sklearn.cluster import KMeans
from sklearn.metrics import pairwise_distances_argmin_min
import matplotlib.pyplot as plt
import matplotlib.cm
import io
from PIL import Image
class Clusters:
def __init__(self, loc_vars):
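        # Group the policies into 1,000 clusters on the calibration variables, then take the
        # policy closest to each cluster centre as that cluster's representative model point.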
self.kmeans = kmeans = KMeans(n_clusters=1000, random_state=0, n_init=10).fit(np.ascontiguousarray(loc_vars))
closest, _ = pairwise_distances_argmin_min(kmeans.cluster_centers_, np.ascontiguousarray(loc_vars))
rep_ids = pd.Series(data=(closest+1)) # 0-based to 1-based indexes
rep_ids.name = 'policy_id'
rep_ids.index.name = 'cluster_id'
self.rep_ids = rep_ids
self.policy_count = self.agg_by_cluster(pd.DataFrame({'policy_count': [1] * len(loc_vars)}))['policy_count']
def agg_by_cluster(self, df, agg=None):
"""Aggregate columns by cluster"""
temp = df.copy()
temp['cluster_id'] = self.kmeans.labels_
temp = temp.set_index('cluster_id')
agg = {c: (agg[c] if c in agg else 'sum') for c in temp.columns} if agg else "sum"
return temp.groupby(temp.index).agg(agg)
def extract_reps(self, df):
"""Extract the rows of representative policies"""
temp = pd.merge(self.rep_ids, df.reset_index(), how='left', on='policy_id')
temp.index.name = 'cluster_id'
return temp.drop('policy_id', axis=1)
def extract_and_scale_reps(self, df, agg=None):
"""Extract and scale the rows of representative policies"""
if agg:
cols = df.columns
mult = pd.DataFrame({c: (self.policy_count if (c not in agg or agg[c] == 'sum') else 1) for c in cols})
return self.extract_reps(df).mul(mult)
else:
return self.extract_reps(df).mul(self.policy_count, axis=0)
def compare(self, df, agg=None):
"""Returns a multi-indexed Dataframe comparing actual and estimate"""
source = self.agg_by_cluster(df, agg)
target = self.extract_and_scale_reps(df, agg)
return pd.DataFrame({'actual': source.stack(), 'estimate':target.stack()})
def compare_total(self, df, agg=None):
"""Aggregate df by columns"""
if agg:
cols = df.columns
op = {c: (agg[c] if c in agg else 'sum') for c in df.columns}
actual = df.agg(op)
estimate = self.extract_and_scale_reps(df, agg=op)
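            # For 'mean' columns, estimate the portfolio mean as a policy-count-weighted
            # average of the representative values.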
op = {k: ((lambda s: s.dot(self.policy_count) / self.policy_count.sum()) if v == 'mean' else v) for k, v in op.items()}
estimate = estimate.agg(op)
else:
actual = df.sum()
estimate = self.extract_and_scale_reps(df).sum()
return pd.DataFrame({'actual': actual, 'estimate': estimate, 'error': estimate / actual - 1})
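# Illustrative usage sketch (mirrors what process_files() does below; assumes a per-policy
# DataFrame such as `cfs`, indexed by 'policy_id' with at least 1,000 rows):
#   clusters = Clusters(cfs)
#   clusters.compare_total(cfs)     # actual vs. estimated column totals and relative error
#   clusters.compare(cfs)           # per-cluster, per-column actual vs. estimate
#   clusters.compare_total(pol_data, agg={'age_at_entry': 'mean'})  # mean-aggregated attributes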
def create_plot(plot_func, *args, **kwargs):
"""Helper function to create plots and return as image"""
plt.figure(figsize=(10, 6))
plot_func(*args, **kwargs)
# Save plot to bytes
buf = io.BytesIO()
plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
buf.seek(0)
plt.close()
return Image.open(buf)
def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
"""Create cashflow comparison plots"""
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
axes = axes.flatten()
for i, (df, title) in enumerate(zip(cfs_list, titles)):
if i < len(axes):
comparison = cluster_obj.compare_total(df)
comparison[['actual', 'estimate']].plot(ax=axes[i], grid=True, title=title)
plt.tight_layout()
buf = io.BytesIO()
plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
buf.seek(0)
plt.close()
return Image.open(buf)
def plot_scatter_comparison(df, title):
"""Create scatter plot comparison"""
plt.figure(figsize=(12, 8))
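    # One colour per value of the second index level (the cashflow/PV column) of the stacked comparison frame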
colors = matplotlib.cm.rainbow(np.linspace(0, 1, len(df.index.levels[1])))
for y, c in zip(df.index.levels[1], colors):
plt.scatter(df.xs(y, level=1)['actual'], df.xs(y, level=1)['estimate'],
color=c, s=9, alpha=0.6)
plt.xlabel('Actual')
plt.ylabel('Estimate')
plt.title(title)
plt.grid(True)
# Draw identity line
lims = [
np.min([plt.xlim(), plt.ylim()]),
np.max([plt.xlim(), plt.ylim()]),
]
plt.plot(lims, lims, 'r-', linewidth=0.5)
plt.xlim(lims)
plt.ylim(lims)
buf = io.BytesIO()
plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
buf.seek(0)
plt.close()
return Image.open(buf)
def process_files(cashflow_base, cashflow_lapse, cashflow_mort, policy_data, pv_base, pv_lapse, pv_mort):
"""Main processing function"""
try:
        # Read the uploaded files. Depending on the Gradio version, gr.File may pass a
        # tempfile-like object (with a .name path) or a plain filepath string, so accept either.
        def _path(f):
            return f if isinstance(f, str) else f.name

        cfs = pd.read_excel(_path(cashflow_base), index_col=0)
        cfs_lapse50 = pd.read_excel(_path(cashflow_lapse), index_col=0)
        cfs_mort15 = pd.read_excel(_path(cashflow_mort), index_col=0)
        pol_data = pd.read_excel(_path(policy_data), index_col=0)
        if pol_data.shape[1] > 4:
            pol_data = pol_data[['age_at_entry', 'policy_term', 'sum_assured', 'duration_mth']]
        pvs = pd.read_excel(_path(pv_base), index_col=0)
        pvs_lapse50 = pd.read_excel(_path(pv_lapse), index_col=0)
        pvs_mort15 = pd.read_excel(_path(pv_mort), index_col=0)
cfs_list = [cfs, cfs_lapse50, cfs_mort15]
pvs_list = [pvs, pvs_lapse50, pvs_mort15]
scen_titles = ['Base', 'Lapse+50%', 'Mort+15%']
results = {}
# 1. Cashflow Calibration
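        # Calibration variables: the per-policy annual cashflows under the base scenario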
cluster_cfs = Clusters(cfs)
# Cashflow comparison tables
results['cf_base_table'] = cluster_cfs.compare_total(cfs)
results['cf_lapse_table'] = cluster_cfs.compare_total(cfs_lapse50)
results['cf_mort_table'] = cluster_cfs.compare_total(cfs_mort15)
# Policy attributes analysis
mean_attrs = {'age_at_entry':'mean', 'policy_term':'mean', 'duration_mth':'mean'}
results['cf_policy_attrs'] = cluster_cfs.compare_total(pol_data, agg=mean_attrs)
# Present value analysis
results['cf_pv_base'] = cluster_cfs.compare_total(pvs)
results['cf_pv_lapse'] = cluster_cfs.compare_total(pvs_lapse50)
results['cf_pv_mort'] = cluster_cfs.compare_total(pvs_mort15)
# Create plots for cashflow calibration
results['cf_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_cfs, scen_titles)
results['cf_scatter_base'] = plot_scatter_comparison(cluster_cfs.compare(cfs), 'Cashflow Calibration - Base Scenario')
# 2. Policy Attribute Calibration
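        # Min-max scale each attribute to [0, 1] so all attributes contribute comparably to the k-means distance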
loc_vars = (pol_data - pol_data.min()) / (pol_data.max() - pol_data.min())
cluster_attrs = Clusters(loc_vars)
results['attr_cf_base'] = cluster_attrs.compare_total(cfs)
results['attr_policy_attrs'] = cluster_attrs.compare_total(pol_data, agg=mean_attrs)
results['attr_pv_base'] = cluster_attrs.compare_total(pvs)
results['attr_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_attrs, scen_titles)
results['attr_scatter_base'] = plot_scatter_comparison(cluster_attrs.compare(cfs), 'Policy Attribute Calibration - Base Scenario')
# 3. Present Value Calibration
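        # Calibration variables: the per-policy present values under the base scenario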
cluster_pvs = Clusters(pvs)
results['pv_cf_base'] = cluster_pvs.compare_total(cfs)
results['pv_policy_attrs'] = cluster_pvs.compare_total(pol_data, agg=mean_attrs)
results['pv_pv_base'] = cluster_pvs.compare_total(pvs)
results['pv_pv_lapse'] = cluster_pvs.compare_total(pvs_lapse50)
results['pv_pv_mort'] = cluster_pvs.compare_total(pvs_mort15)
results['pv_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_pvs, scen_titles)
results['pv_scatter_base'] = plot_scatter_comparison(cluster_pvs.compare(pvs), 'Present Value Calibration - Base Scenario')
# Summary comparison plot
fig, ax = plt.subplots(figsize=(12, 8))
        # Mean absolute error per method (matches the 'Mean Absolute Error' axis label below)
        comparison_data = {
            'Cashflow Calibration': [
                cluster_cfs.compare_total(cfs)['error'].abs().mean(),
                cluster_cfs.compare_total(pvs)['error'].abs().mean()
            ],
            'Policy Attribute Calibration': [
                cluster_attrs.compare_total(cfs)['error'].abs().mean(),
                cluster_attrs.compare_total(pvs)['error'].abs().mean()
            ],
            'Present Value Calibration': [
                cluster_pvs.compare_total(cfs)['error'].abs().mean(),
                cluster_pvs.compare_total(pvs)['error'].abs().mean()
            ]
        }
x = np.arange(2)
width = 0.25
ax.bar(x - width, comparison_data['Cashflow Calibration'], width, label='Cashflow Calibration')
ax.bar(x, comparison_data['Policy Attribute Calibration'], width, label='Policy Attribute Calibration')
ax.bar(x + width, comparison_data['Present Value Calibration'], width, label='Present Value Calibration')
ax.set_ylabel('Mean Absolute Error')
ax.set_title('Calibration Method Comparison')
ax.set_xticks(x)
ax.set_xticklabels(['Cashflows', 'Present Values'])
ax.legend()
ax.grid(True, alpha=0.3)
buf = io.BytesIO()
plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
buf.seek(0)
plt.close()
results['summary_plot'] = Image.open(buf)
return results
except Exception as e:
return {"error": f"Error processing files: {str(e)}"}
def create_interface():
with gr.Blocks(title="Cluster Model Points Analysis", theme=gr.themes.Soft()) as demo:
gr.Markdown("""
# Cluster Model Points Analysis
This application applies cluster analysis to model point selection for insurance portfolios.
Upload your Excel files to analyze cashflows, policy attributes, and present values using different calibration methods.
**Required Files:**
- 3 Cashflow files (Base, Lapse stress, Mortality stress scenarios)
- 1 Policy data file
- 3 Present value files (Base, Lapse stress, Mortality stress scenarios)
""")
with gr.Row():
with gr.Column():
gr.Markdown("### Upload Files")
cashflow_base = gr.File(label="Cashflows - Base Scenario", file_types=[".xlsx"])
cashflow_lapse = gr.File(label="Cashflows - Lapse Stress (+50%)", file_types=[".xlsx"])
cashflow_mort = gr.File(label="Cashflows - Mortality Stress (+15%)", file_types=[".xlsx"])
policy_data = gr.File(label="Policy Data", file_types=[".xlsx"])
pv_base = gr.File(label="Present Values - Base Scenario", file_types=[".xlsx"])
pv_lapse = gr.File(label="Present Values - Lapse Stress", file_types=[".xlsx"])
pv_mort = gr.File(label="Present Values - Mortality Stress", file_types=[".xlsx"])
analyze_btn = gr.Button("Analyze", variant="primary", size="lg")
with gr.Tabs():
with gr.TabItem("Summary"):
summary_plot = gr.Image(label="Calibration Methods Comparison")
with gr.TabItem("Cashflow Calibration"):
gr.Markdown("### Results using Annual Cashflows as Calibration Variables")
with gr.Row():
cf_base_table = gr.Dataframe(label="Base Scenario Comparison")
cf_policy_attrs = gr.Dataframe(label="Policy Attributes Comparison")
cf_cashflow_plot = gr.Image(label="Cashflow Comparisons Across Scenarios")
cf_scatter_base = gr.Image(label="Scatter Plot - Base Scenario")
with gr.Row():
cf_pv_base = gr.Dataframe(label="Present Values - Base")
cf_pv_lapse = gr.Dataframe(label="Present Values - Lapse Stress")
cf_pv_mort = gr.Dataframe(label="Present Values - Mortality Stress")
with gr.TabItem("Policy Attribute Calibration"):
gr.Markdown("### Results using Policy Attributes as Calibration Variables")
with gr.Row():
attr_cf_base = gr.Dataframe(label="Cashflows - Base Scenario")
attr_policy_attrs = gr.Dataframe(label="Policy Attributes Comparison")
attr_cashflow_plot = gr.Image(label="Cashflow Comparisons Across Scenarios")
attr_scatter_base = gr.Image(label="Scatter Plot - Base Scenario")
attr_pv_base = gr.Dataframe(label="Present Values - Base Scenario")
with gr.TabItem("Present Value Calibration"):
gr.Markdown("### Results using Present Values as Calibration Variables")
with gr.Row():
pv_cf_base = gr.Dataframe(label="Cashflows - Base Scenario")
pv_policy_attrs = gr.Dataframe(label="Policy Attributes Comparison")
pv_cashflow_plot = gr.Image(label="Cashflow Comparisons Across Scenarios")
pv_scatter_base = gr.Image(label="Scatter Plot - Base Scenario")
with gr.Row():
pv_pv_base = gr.Dataframe(label="Present Values - Base")
pv_pv_lapse = gr.Dataframe(label="Present Values - Lapse Stress")
pv_pv_mort = gr.Dataframe(label="Present Values - Mortality Stress")
def update_interface(cashflow_base, cashflow_lapse, cashflow_mort, policy_data, pv_base, pv_lapse, pv_mort):
if not all([cashflow_base, cashflow_lapse, cashflow_mort, policy_data, pv_base, pv_lapse, pv_mort]):
                return [None] * 20  # one placeholder for each of the 20 output components
results = process_files(cashflow_base, cashflow_lapse, cashflow_mort, policy_data, pv_base, pv_lapse, pv_mort)
if "error" in results:
gr.Warning(results["error"])
                return [None] * 20
return [
results.get('summary_plot'),
results.get('cf_base_table'),
results.get('cf_policy_attrs'),
results.get('cf_cashflow_plot'),
results.get('cf_scatter_base'),
results.get('cf_pv_base'),
results.get('cf_pv_lapse'),
results.get('cf_pv_mort'),
results.get('attr_cf_base'),
results.get('attr_policy_attrs'),
results.get('attr_cashflow_plot'),
results.get('attr_scatter_base'),
results.get('attr_pv_base'),
results.get('pv_cf_base'),
results.get('pv_policy_attrs'),
results.get('pv_cashflow_plot'),
results.get('pv_scatter_base'),
results.get('pv_pv_base'),
results.get('pv_pv_lapse'),
results.get('pv_pv_mort')
]
analyze_btn.click(
update_interface,
inputs=[cashflow_base, cashflow_lapse, cashflow_mort, policy_data, pv_base, pv_lapse, pv_mort],
outputs=[
summary_plot,
cf_base_table, cf_policy_attrs, cf_cashflow_plot, cf_scatter_base,
cf_pv_base, cf_pv_lapse, cf_pv_mort,
attr_cf_base, attr_policy_attrs, attr_cashflow_plot, attr_scatter_base, attr_pv_base,
pv_cf_base, pv_policy_attrs, pv_cashflow_plot, pv_scatter_base,
pv_pv_base, pv_pv_lapse, pv_pv_mort
]
)
return demo
if __name__ == "__main__":
demo = create_interface()
demo.launch()
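# Local run sketch (assumed dependencies; a Space would normally pin these in requirements.txt):
#   pip install gradio pandas scikit-learn matplotlib openpyxl pillow
#   python app.py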