|
import functools
import os

import gradio as gr
import joblib
import pandas as pd
from imblearn.over_sampling import SMOTE
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
|
|
|
|
|
# Artefact and data locations, relative to the directory the app is run from.
data_dir = 'datasets'
model_dir = 'models'

# Preprocessor that converts raw feature rows into the inputs the models expect;
# it is applied to both the user's sample and the reference dataset.
preprocessor_path = os.path.join(model_dir, 'churn_preprocessor.joblib')
loaded_preprocessor = joblib.load(preprocessor_path)
|
|
|
# Display names of the classifiers shown in the results table. Each one is
# stored in model_dir as "<display name without spaces>.joblib".
model_names = [
    'Ada Boost Classifier',
    'LGBM Classifier',
    'LogisticRegression',
    'XGBoost Classifier',
]


def _model_file(display_name):
    """Return the joblib path for *display_name* (spaces stripped) inside model_dir."""
    return os.path.join(model_dir, f"{display_name.replace(' ', '')}.joblib")


model_paths = {label: _model_file(label) for label in model_names}

# Best-effort loading: a model whose file cannot be loaded is reported on
# stdout and left out, so the app still starts with the models that did load.
models = {}
for name, path in model_paths.items():
    try:
        models[name] = joblib.load(path)
    except Exception as e:
        print(f"Error loading model {name} from {path}: {e}")
|
|
|
|
|
# Reference dataset: the 'Churn' column is the target, all other columns are
# the features fed to the preprocessor.
data_path = os.path.join(data_dir, 'cleaned_IT_customer_churn.csv')
df = pd.read_csv(data_path)
y = df['Churn']
X = df.drop(columns='Churn')
|
|
|
|
|
# Allowed options for each categorical UI control; the first option in each
# list is used as the control's default value. 'others' is the shared No/Yes
# choice set for every binary question.
input_choices = dict(
    gender=['Female', 'Male'],
    internet_service=['DSL', 'Fiber optic', 'No'],
    contract=['Month-to-month', 'One year', 'Two year'],
    payment_method=[
        'Electronic check',
        'Mailed check',
        'Bank transfer (automatic)',
        'Credit card (automatic)',
    ],
    others=['No', 'Yes'],
)
|
|
|
|
|
# Mean and maximum of the numeric features, used for the UI's default values
# and upper bounds. Rows are selected by their aggregation label instead of by
# position: this does not rely on the order given to agg(), and each row stays
# a numeric Series without the extra 'index' entry that reset_index() would add.
stats = df[['tenure', 'MonthlyCharges', 'TotalCharges']].agg(['mean', 'max'])
means = stats.loc['mean']
maxs = stats.loc['max']
|
|
|
|
|
def calculate_metrics(y_true, y_pred, pos_label=1):
    """Score predicted labels against true labels, as percentages.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels, aligned element-wise with ``y_true``.
        pos_label: Label treated as the positive class for recall, F1 and
            precision. Defaults to 1, scikit-learn's own default.

    Returns:
        dict with keys 'Accuracy', 'Recall', 'F1 Score' and 'Precision', each
        expressed on a 0-100 scale.
    """
    # zero_division=0 returns the same 0.0 that scikit-learn's default ("warn")
    # yields when a score is ill-defined (e.g. a model that never predicts the
    # positive class), but without emitting an UndefinedMetricWarning each call.
    binary_kwargs = {'pos_label': pos_label, 'zero_division': 0}
    return {
        'Accuracy': accuracy_score(y_true, y_pred) * 100,
        'Recall': recall_score(y_true, y_pred, **binary_kwargs) * 100,
        'F1 Score': f1_score(y_true, y_pred, **binary_kwargs) * 100,
        'Precision': precision_score(y_true, y_pred, **binary_kwargs) * 100,
    }
|
|
|
|
|
def _yes_to_int(answer):
    """Encode a 'Yes'/'No' radio answer as 1 for 'Yes' and 0 otherwise."""
    return int(answer == 'Yes')


@functools.lru_cache(maxsize=1)
def _reference_metrics():
    """Score every loaded model once on the SMOTE-resampled reference dataset.

    The result does not depend on the user's input, and SMOTE runs with a
    fixed random_state, so the scores are computed on first use and cached.
    Before, they were recomputed (full transform + SMOTE + predict) on every
    request. If computation raises, nothing is cached and the next call retries.

    Returns:
        dict mapping model name -> metrics dict from ``calculate_metrics``.
    """
    X_trans = loaded_preprocessor.transform(X)
    X_resampled, y_resampled = SMOTE(random_state=42).fit_resample(X_trans, y)
    # NOTE(review): these are in-sample scores on the full (resampled) dataset,
    # presumably the data the models were trained on, so they likely overstate
    # real-world performance. A held-out split would give honest numbers.
    return {
        name: calculate_metrics(y_resampled, model.predict(X_resampled))
        for name, model in models.items()
    }


def load_and_predict(
        gender, internet_service, contract, payment_method, tenure, monthly_charges, total_charges,
        senior_citizen, partner, dependents, phone_service, multiple_lines, online_security, online_backup,
        device_protection, tech_support, streaming_tv, streaming_movies, paperless_billing):
    """Predict churn for one customer with every loaded model.

    Args:
        gender: 'Female' or 'Male'; encoded as 1 for 'Male', else 0.
        internet_service, contract, payment_method: Category labels passed to
            the preprocessor unchanged, as strings.
        tenure: Tenure in months; cast to int.
        monthly_charges, total_charges: Cast to float.
        senior_citizen ... paperless_billing: 'Yes'/'No' answers, encoded as
            1 for 'Yes' and 0 otherwise.

    Returns:
        pandas.DataFrame with one row per model and columns 'Model',
        'Predicted Churn' ('Yes'/'No'), 'Accuracy', 'Recall', 'F1 Score' and
        'Precision'. Rows are sorted by 'Accuracy', best first. The metric
        columns are the models' cached scores on the reference dataset; they
        do not depend on this customer.

    Raises:
        gr.Error: if no model is loaded or anything fails while building the
            sample, transforming it or predicting. Gradio shows the message
            to the user.
    """
    if not models:
        raise gr.Error("No models were loaded, so no prediction can be made.")

    try:
        # Key order matches the original feature layout; the preprocessor may
        # check column names and order against what it was fitted on.
        sample = {
            'gender': int(gender == 'Male'),
            'SeniorCitizen': _yes_to_int(senior_citizen),
            'Partner': _yes_to_int(partner),
            'Dependents': _yes_to_int(dependents),
            'tenure': int(tenure),
            'PhoneService': _yes_to_int(phone_service),
            'MultipleLines': _yes_to_int(multiple_lines),
            'InternetService': str(internet_service),
            'OnlineSecurity': _yes_to_int(online_security),
            'OnlineBackup': _yes_to_int(online_backup),
            'DeviceProtection': _yes_to_int(device_protection),
            'TechSupport': _yes_to_int(tech_support),
            'StreamingTV': _yes_to_int(streaming_tv),
            'StreamingMovies': _yes_to_int(streaming_movies),
            'Contract': str(contract),
            'PaperlessBilling': _yes_to_int(paperless_billing),
            'PaymentMethod': str(payment_method),
            'MonthlyCharges': float(monthly_charges),
            'TotalCharges': float(total_charges),
        }
        sample_trans = loaded_preprocessor.transform(pd.DataFrame([sample]))
        reference_metrics = _reference_metrics()
        results = [
            {
                'Model': name,
                'Predicted Churn': 'Yes' if model.predict(sample_trans)[0] == 1 else 'No',
                **reference_metrics[name],
            }
            for name, model in models.items()
        ]
    except Exception as e:
        # Returning the message (as before) would hand a plain string to the
        # DataFrame output, which cannot display it as a table. Gradio shows
        # a raised gr.Error's message in the UI instead.
        raise gr.Error(f"An error occurred during model loading or prediction: {e}") from e

    return (
        pd.DataFrame(results)
        .sort_values(by='Accuracy', ascending=False)
        .reset_index(drop=True)
    )
|
|
|
|
|
# Binary Yes/No questions, in the same order as the matching trailing
# parameters of load_and_predict (senior_citizen ... paperless_billing).
_yes_no_labels = [
    "Senior Citizen", "Partner", "Dependents", "Phone Service",
    "Multiple Lines", "Online Security", "Online Backup", "Device Protection",
    "Tech Support", "Streaming TV", "Streaming Movies", "Paperless Billing",
]

# gr.Interface passes component values to load_and_predict positionally, so
# the order of this list must match that function's parameter order.
input_components = [
    gr.Radio(label="Gender", choices=input_choices['gender'], value=input_choices['gender'][0]),
    *(
        gr.Dropdown(label=label, choices=input_choices[key], value=input_choices[key][0])
        for label, key in (
            ("Internet Service", 'internet_service'),
            ("Contract", 'contract'),
            ("Payment Method", 'payment_method'),
        )
    ),
    # Numeric inputs start at the dataset mean and allow up to 1.5x the
    # largest value present in the dataset.
    gr.Slider(label="Tenure (Months)", minimum=0, maximum=int(maxs['tenure'] * 1.5), value=int(means['tenure'])),
    gr.Number(label="Monthly Charges", minimum=0.0, maximum=float(maxs['MonthlyCharges'] * 1.5), value=float(means['MonthlyCharges'])),
    gr.Number(label="Total Charges", minimum=0.0, maximum=float(maxs['TotalCharges'] * 1.5), value=float(means['TotalCharges'])),
    *(
        gr.Radio(label=label, choices=input_choices['others'], value=input_choices['others'][0])
        for label in _yes_no_labels
    ),
]
|
|
|
# Results table filled with the DataFrame returned by load_and_predict.
output_component = gr.DataFrame()

# Assemble the UI and start the server. As before, launching happens as soon
# as this module runs.
demo = gr.Interface(
    fn=load_and_predict,
    inputs=input_components,
    outputs=output_component,
    title="♻️ Customer Churn Prediction",
    description="Enter the following information to predict customer churn.",
    flagging_mode="never",
)
demo.launch()