Update app.py
Browse files
app.py
CHANGED
@@ -1,152 +1,150 @@
|
|
1 |
-
import os
|
2 |
-
import joblib
|
3 |
-
import pandas as pd
|
4 |
-
from imblearn.over_sampling import SMOTE
|
5 |
-
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
|
6 |
-
import gradio as gr
|
7 |
-
|
8 |
-
# Load models and preprocessor
|
9 |
-
model_dir = 'models'
|
10 |
-
data_dir = 'datasets'
|
11 |
-
|
12 |
-
preprocessor_path = os.path.join(model_dir, 'churn_preprocessor.joblib')
|
13 |
-
loaded_preprocessor = joblib.load(preprocessor_path)
|
14 |
-
|
15 |
-
model_names = [
|
16 |
-
'Ada Boost Classifier',
|
17 |
-
'
|
18 |
-
'
|
19 |
-
'
|
20 |
-
'
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
'
|
45 |
-
'
|
46 |
-
'
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
'
|
60 |
-
'
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
'
|
75 |
-
'
|
76 |
-
'
|
77 |
-
'
|
78 |
-
'
|
79 |
-
'
|
80 |
-
'
|
81 |
-
'
|
82 |
-
'
|
83 |
-
'
|
84 |
-
'
|
85 |
-
'
|
86 |
-
'
|
87 |
-
'
|
88 |
-
'
|
89 |
-
'
|
90 |
-
'
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
gr.
|
122 |
-
gr.Dropdown(label="
|
123 |
-
gr.
|
124 |
-
gr.
|
125 |
-
gr.
|
126 |
-
gr.
|
127 |
-
gr.
|
128 |
-
gr.Radio(label="
|
129 |
-
gr.Radio(label="
|
130 |
-
gr.Radio(label="
|
131 |
-
gr.Radio(label="
|
132 |
-
gr.Radio(label="
|
133 |
-
gr.Radio(label="
|
134 |
-
gr.Radio(label="
|
135 |
-
gr.Radio(label="
|
136 |
-
gr.Radio(label="
|
137 |
-
gr.Radio(label="
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
description="Enter the following information to predict customer churn.",
|
151 |
-
flagging_mode="never" # Replacing allow_flagging with flagging_mode
|
152 |
).launch()
|
|
|
1 |
+
import os
|
2 |
+
import joblib
|
3 |
+
import pandas as pd
|
4 |
+
from imblearn.over_sampling import SMOTE
|
5 |
+
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
|
6 |
+
import gradio as gr
|
7 |
+
|
8 |
+
# Directories holding the serialized artifacts and the dataset.
model_dir = 'models'
data_dir = 'datasets'

# The preprocessor is loaded eagerly: a failure here aborts start-up.
preprocessor_path = os.path.join(model_dir, 'churn_preprocessor.joblib')
loaded_preprocessor = joblib.load(preprocessor_path)

# Display names of the classifiers; the file name of each one is the
# display name with spaces removed plus a ".joblib" suffix.
model_names = [
    'Ada Boost Classifier',
    'Gradient Boosting Classifier',
    'LGBM Classifier',
    'LogisticRegression',
    'XGBoost Classifier',
]
model_paths = {
    label: os.path.join(model_dir, label.replace(' ', '') + '.joblib')
    for label in model_names
}

# Best-effort loading: a model that cannot be read is reported on stdout
# and simply left out of `models`.
models = {}
for name, path in model_paths.items():
    try:
        loaded = joblib.load(path)
    except Exception as e:
        print(f"Error loading model {name} from {path}: {str(e)}")
    else:
        models[name] = loaded
|
31 |
+
|
32 |
+
# Read the churn dataset from the data directory.
data_path = os.path.join(data_dir, 'cleaned_IT_customer_churn.csv')
df = pd.read_csv(data_path)

# Split into the feature frame (every column except 'Churn') and the
# target column.
X, y = df.drop(columns=['Churn']), df['Churn']
|
39 |
+
|
40 |
+
# Allowed values for the categorical form fields. 'others' is the shared
# No/Yes choice list.
input_choices = dict(
    gender=['Female', 'Male'],
    internet_service=['DSL', 'Fiber optic', 'No'],
    contract=['Month-to-month', 'One year', 'Two year'],
    payment_method=[
        'Electronic check',
        'Mailed check',
        'Bank transfer (automatic)',
        'Credit card (automatic)',
    ],
    others=['No', 'Yes'],
)
|
48 |
+
|
49 |
+
# Pre-computed statistics for the numeric columns.
# `agg` returns one row per aggregation, labelled by the aggregation name,
# so the rows are selected by label rather than by position. This keeps the
# lookup correct if the aggregation list is reordered and avoids the stray
# 'index' entry that reset_index() would add to each row.
stats = df[['tenure', 'MonthlyCharges', 'TotalCharges']].agg(['mean', 'max'])
means = stats.loc['mean']
maxs = stats.loc['max']
|
53 |
+
|
54 |
+
# Metrics calculation function
|
55 |
+
def calculate_metrics(y_true, y_pred):
    """Score predictions against the true labels.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels, aligned with ``y_true``.

    Returns:
        A dict with the keys 'Accuracy', 'Recall', 'F1 Score' and
        'Precision' (in that order), each scaled by 100.
    """
    scorers = (
        ('Accuracy', accuracy_score),
        ('Recall', recall_score),
        ('F1 Score', f1_score),
        ('Precision', precision_score),
    )
    return {label: scorer(y_true, y_pred) * 100 for label, scorer in scorers}
|
62 |
+
|
63 |
+
# Prediction and metrics evaluation function
|
64 |
+
# Per-model evaluation metrics, filled on first use. They depend only on the
# loaded dataset, the preprocessor and the models (SMOTE is seeded), so they
# are identical for every request and need to be computed only once.
_MODEL_METRICS = {}


def _get_model_metrics():
    """Return ``{model name: metrics dict}``, computing it on the first call."""
    if not _MODEL_METRICS:
        X_trans = loaded_preprocessor.transform(X)

        # Using SMOTE to handle class imbalance.
        # NOTE(review): metrics are measured on the resampled full dataset,
        # which presumably overlaps the training data — confirm intent.
        X_resampled, y_resampled = SMOTE(random_state=42).fit_resample(X_trans, y)

        # Build the full mapping first so a failure part-way through does
        # not leave a partially filled cache behind.
        computed = {
            name: calculate_metrics(y_resampled, model.predict(X_resampled))
            for name, model in models.items()
        }
        _MODEL_METRICS.update(computed)
    return _MODEL_METRICS


def load_and_predict(
        gender, internet_service, contract, payment_method, tenure, monthly_charges, total_charges,
        senior_citizen, partner, dependents, phone_service, multiple_lines, online_security, online_backup,
        device_protection, tech_support, streaming_tv, streaming_movies, paperless_billing):
    """Predict churn for one customer with every loaded model.

    Args:
        gender: 'Male' or 'Female'.
        internet_service, contract, payment_method: Categorical values
            passed to the preprocessor as strings.
        tenure, monthly_charges, total_charges: Numeric inputs; converted
            with ``int``/``float`` before preprocessing.
        senior_citizen ... paperless_billing: 'Yes'/'No' answers, encoded
            as 1 for 'Yes' and 0 otherwise.

    Returns:
        A DataFrame with one row per model (columns 'Model',
        'Predicted Churn' and the metric columns), sorted by 'Accuracy'
        in descending order.

    Raises:
        gr.Error: If no model is loaded, a numeric input is missing, or
            preprocessing/prediction fails. An exception is raised instead
            of returning the message because the output component is a
            DataFrame, which does not display a returned string as text.
    """
    if not models:
        raise gr.Error("No models are available for prediction.")

    # Ensure the numeric inputs are not None (e.g. a cleared number field).
    numeric_inputs = {
        'Tenure (Months)': tenure,
        'Monthly Charges': monthly_charges,
        'Total Charges': total_charges,
    }
    missing = [label for label, value in numeric_inputs.items() if value is None]
    if missing:
        raise gr.Error(f"Please provide a value for: {', '.join(missing)}.")

    try:
        sample = {
            'gender': int(gender == 'Male'),
            'SeniorCitizen': int(senior_citizen == 'Yes'),
            'Partner': int(partner == 'Yes'),
            'Dependents': int(dependents == 'Yes'),
            'tenure': int(tenure),
            'PhoneService': int(phone_service == 'Yes'),
            'MultipleLines': int(multiple_lines == 'Yes'),
            'InternetService': str(internet_service),
            'OnlineSecurity': int(online_security == 'Yes'),
            'OnlineBackup': int(online_backup == 'Yes'),
            'DeviceProtection': int(device_protection == 'Yes'),
            'TechSupport': int(tech_support == 'Yes'),
            'StreamingTV': int(streaming_tv == 'Yes'),
            'StreamingMovies': int(streaming_movies == 'Yes'),
            'Contract': str(contract),
            'PaperlessBilling': int(paperless_billing == 'Yes'),
            'PaymentMethod': str(payment_method),
            'MonthlyCharges': float(monthly_charges),
            'TotalCharges': float(total_charges),
        }

        sample_df = pd.DataFrame([sample])
        sample_trans = loaded_preprocessor.transform(sample_df)
        metrics_by_model = _get_model_metrics()

        results = []
        for name, model in models.items():
            churn_pred = model.predict(sample_trans)
            results.append({
                'Model': name,
                'Predicted Churn': 'Yes' if churn_pred[0] == 1 else 'No',
                **metrics_by_model[name],
            })
    except Exception as e:
        raise gr.Error(
            f"An error occurred during model loading or prediction: {str(e)}"
        ) from e

    return pd.DataFrame(results).sort_values(by='Accuracy', ascending=False).reset_index(drop=True)
|
116 |
+
|
117 |
+
# Gradio input widgets. Their order must match the positional parameters of
# load_and_predict, since they are passed to it as `inputs`.
_yes_no = input_choices['others']
_yes_no_labels = (
    "Senior Citizen",
    "Partner",
    "Dependents",
    "Phone Service",
    "Multiple Lines",
    "Online Security",
    "Online Backup",
    "Device Protection",
    "Tech Support",
    "Streaming TV",
    "Streaming Movies",
    "Paperless Billing",
)

input_components = [
    gr.Radio(
        label="Gender",
        choices=input_choices['gender'],
        value=input_choices['gender'][0],
    ),
    gr.Dropdown(
        label="Internet Service",
        choices=input_choices['internet_service'],
        value=input_choices['internet_service'][0],
    ),
    gr.Dropdown(
        label="Contract",
        choices=input_choices['contract'],
        value=input_choices['contract'][0],
    ),
    gr.Dropdown(
        label="Payment Method",
        choices=input_choices['payment_method'],
        value=input_choices['payment_method'][0],
    ),
    # Numeric fields default to the dataset mean and allow up to 1.5x the
    # dataset maximum.
    gr.Slider(
        label="Tenure (Months)",
        minimum=0,
        maximum=int(maxs['tenure'] * 1.5),
        value=int(means['tenure']),
    ),
    gr.Number(
        label="Monthly Charges",
        minimum=0.0,
        maximum=float(maxs['MonthlyCharges'] * 1.5),
        value=float(means['MonthlyCharges']),
    ),
    gr.Number(
        label="Total Charges",
        minimum=0.0,
        maximum=float(maxs['TotalCharges'] * 1.5),
        value=float(means['TotalCharges']),
    ),
    # The remaining fields are all No/Yes radios defaulting to the first
    # choice.
    *[
        gr.Radio(label=text, choices=_yes_no, value=_yes_no[0])
        for text in _yes_no_labels
    ],
]
|
139 |
+
|
140 |
+
# Results are shown as a table: one row per model.
output_component = gr.DataFrame()

# Build the Gradio interface, then launch it.
demo = gr.Interface(
    fn=load_and_predict,
    inputs=input_components,
    outputs=output_component,
    title="♻️ Customer Churn Prediction",
    description="Enter the following information to predict customer churn.",
    flagging_mode="never",  # Replacing allow_flagging with flagging_mode
)
demo.launch()
|