"""Gradio app for actuarial model point selection.

Clusters a seriatim policy file with k-means, takes the policy closest to
each cluster centre as a model point, and compares proxy (weighted model
point) cashflows and present values against the full seriatim run.
"""

import io

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image
from sklearn.cluster import KMeans
from sklearn.metrics import pairwise_distances_argmin_min, r2_score


def cluster_analysis(policy_file, cashflow_file, pv_file, num_clusters):
    # Gradio may pass a tempfile wrapper (with a .name path attribute) or a
    # plain path string, depending on version/config; handle both.
    def _path(f):
        return getattr(f, "name", f)

    # Basic checks and reads
    try:
        policy_df = pd.read_csv(_path(policy_file), index_col=0)
        cashflow_df = pd.read_csv(_path(cashflow_file), index_col=0)
        pv_df = pd.read_csv(_path(pv_file), index_col=0)
    except Exception as e:
        return (None, None, None,
                f"Error reading CSV files: {e}. Ensure files are CSVs and the "
                f"first column is the index (e.g., Policy ID).")

    # Use policy attributes for clustering.
    # These column names must match your policy data CSV.
    required_cols = ['IssueAge', 'PolicyTerm', 'SumAssured', 'Duration']
    if not all(col in policy_df.columns for col in required_cols):
        missing_cols = [col for col in required_cols if col not in policy_df.columns]
        return (None, None, None,
                f"Policy data missing required columns: {missing_cols}. "
                f"Please ensure your policy CSV has these columns.")

    X = policy_df[required_cols].fillna(0)  # Simple imputation

    # Guard against zero-variance columns (e.g., all values identical after
    # fillna), which cannot be standardised.
    X_std = X.std()
    if (X_std == 0).any():
        zero_std_cols = X_std[X_std == 0].index.tolist()
        return (None, None, None,
                f"Error: Columns {zero_std_cols} have zero standard deviation "
                f"after fillna(0), so they cannot be scaled. Please check your data.")
    X_scaled = (X - X.mean()) / X_std

    # Cluster
    try:
        kmeans = KMeans(n_clusters=int(num_clusters), random_state=42, n_init=10)
        kmeans.fit(X_scaled)
        policy_df['Cluster'] = kmeans.labels_
    except Exception as e:
        return (None, None, None, f"Clustering error: {e}")

    # Select model points: the policy closest to each cluster centre
    centers = kmeans.cluster_centers_
    closest, _ = pairwise_distances_argmin_min(centers, X_scaled)
    model_points = policy_df.iloc[closest].copy()

    # Calculate weights (policy count per cluster)
    counts = policy_df['Cluster'].value_counts()
    model_points['Weight'] = model_points['Cluster'].map(counts)

    # Ensure the model point indices exist in cashflow_df and pv_df
    if not model_points.index.isin(cashflow_df.index).all():
        return (None, None, None,
                "Error: Model point indices not found in cashflow data. "
                "Ensure Policy IDs match.")
    if not model_points.index.isin(pv_df.index).all():
        return (None, None, None,
                "Error: Model point indices not found in PV data. "
                "Ensure Policy IDs match.")
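    # From here on, the function assembles its four outputs: the model point
    # CSV for download, two comparison plots, and a short accuracy summary.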
Ensure Policy IDs match.") # Create CSV for download csv_buffer = io.StringIO() model_points.to_csv(csv_buffer) # index=True by default, which is good if index is PolicyID csv_data = csv_buffer.getvalue() # Aggregate cashflows weighted by cluster counts # Ensure model_points['Weight'] is numeric for multiplication model_points['Weight'] = pd.to_numeric(model_points['Weight'], errors='coerce').fillna(1) proxy_cashflows_df = cashflow_df.loc[model_points.index] proxy_cashflows = proxy_cashflows_df.multiply(model_points['Weight'].values, axis=0).sum() seriatim_cashflows = cashflow_df.sum() # Plot aggregated cashflows fig, ax = plt.subplots(figsize=(8,4)) seriatim_cashflows.plot(ax=ax, label='Seriatim Cashflows') proxy_cashflows.plot(ax=ax, label='Proxy Cashflows', linestyle='--') ax.set_title('Aggregated Cashflows Comparison') ax.legend() ax.grid(True) plt.tight_layout() buf = io.BytesIO() plt.savefig(buf, format='png') plt.close(fig) buf.seek(0) cashflow_plot = buf.read() # Aggregate present values weighted proxy_pv_df = pv_df.loc[model_points.index] # Assuming pv_df has one column of PVs, or sum all columns if multiple if proxy_pv_df.shape[1] > 1: proxy_pv = proxy_pv_df.multiply(model_points['Weight'].values, axis=0).sum().sum() seriatim_pv = pv_df.sum().sum() else: proxy_pv = proxy_pv_df.multiply(model_points['Weight'].values, axis=0).sum().iloc[0] seriatim_pv = pv_df.sum().iloc[0] # Present Value comparison plot (bar) fig2, ax2 = plt.subplots(figsize=(5,4)) ax2.bar(['Seriatim PV', 'Proxy PV'], [seriatim_pv, proxy_pv], color=['blue', 'orange']) ax2.set_title('Aggregated Present Values') ax2.grid(axis='y') plt.tight_layout() buf2 = io.BytesIO() plt.savefig(buf2, format='png') plt.close(fig2) buf2.seek(0) pv_plot = buf2.read() # Accuracy metrics common_idx = seriatim_cashflows.index.intersection(proxy_cashflows.index) if not common_idx.empty: r2 = r2_score(seriatim_cashflows.loc[common_idx], proxy_cashflows.loc[common_idx]) else: r2 = float('nan') # Or handle as error pv_error = abs(proxy_pv - seriatim_pv) / seriatim_pv * 100 if seriatim_pv != 0 else float('inf') metrics_text = ( f"R-squared for aggregated cashflows: {r2:.4f}\n" f"Absolute percentage error in present value: {pv_error:.4f}%" ) return csv_data, cashflow_plot, pv_plot, metrics_text with gr.Blocks() as demo: gr.Markdown("# Actuarial Model Point Selection (CSV Upload)") with gr.Row(): with gr.Column(): policy_input = gr.File(label="Upload Policy Data (CSV with PolicyID as first column)") cashflow_input = gr.File(label="Upload Cashflow Data (CSV with PolicyID as first column)") pv_input = gr.File(label="Upload Present Value Data (CSV with PolicyID as first column)") clusters_input = gr.Slider(minimum=2, maximum=100, step=1, value=10, label="Number of Model Points") run_btn = gr.Button("Run Clustering") with gr.Column(): output_csv = gr.Textbox(label="Model Points CSV Output", lines=10, interactive=False) cashflow_img = gr.Image(label="Aggregated Cashflows Comparison", type="pil") # Using PIL for better compatibility pv_img = gr.Image(label="Aggregated Present Values Comparison", type="pil") metrics_box = gr.Textbox(label="Accuracy Metrics", lines=4, interactive=False) run_btn.click( cluster_analysis, inputs=[policy_input, cashflow_input, pv_input, clusters_input], outputs=[output_csv, cashflow_img, pv_img, metrics_box] ) if __name__ == '__main__': demo.launch(debug=True)