AfshinMA commited on
Commit
bdb54f3
·
verified ·
1 Parent(s): b1c9c69

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +149 -151
app.py CHANGED
@@ -1,152 +1,150 @@
1
- import os
2
- import joblib
3
- import pandas as pd
4
- from imblearn.over_sampling import SMOTE
5
- from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
6
- import gradio as gr
7
-
8
- # Load models and preprocessor
9
- model_dir = 'models'
10
- data_dir = 'datasets'
11
-
12
- preprocessor_path = os.path.join(model_dir, 'churn_preprocessor.joblib')
13
- loaded_preprocessor = joblib.load(preprocessor_path)
14
-
15
- model_names = [
16
- 'Ada Boost Classifier',
17
- 'Extra Trees Classifier',
18
- 'Gradient Boosting Classifier',
19
- 'LGBM Classifier',
20
- 'LogisticRegression',
21
- 'RandomForestClassifier'
22
- 'XGBoost Classifier',
23
- ]
24
- model_paths = {name: os.path.join(model_dir, f"{name.replace(' ', '')}.joblib") for name in model_names}
25
-
26
- # Load models safely
27
- models = {}
28
- for name, path in model_paths.items():
29
- try:
30
- models[name] = joblib.load(path)
31
- except Exception as e:
32
- print(f"Error loading model {name} from {path}: {str(e)}")
33
-
34
- # Load dataset
35
- data_path = os.path.join(data_dir, 'cleaned_IT_customer_churn.csv')
36
- df = pd.read_csv(data_path)
37
-
38
- # Prepare features and target
39
- X = df.drop(columns=['Churn'])
40
- y = df['Churn']
41
-
42
- # Predefined input choices
43
- input_choices = {
44
- 'gender': ['Female', 'Male'],
45
- 'internet_service': ['DSL', 'Fiber optic', 'No'],
46
- 'contract': ['Month-to-month', 'One year', 'Two year'],
47
- 'payment_method': ['Electronic check', 'Mailed check', 'Bank transfer (automatic)', 'Credit card (automatic)'],
48
- 'others' : ['No', 'Yes']
49
- }
50
-
51
- # Pre-computed statistics for default values
52
- stats = df[['tenure', 'MonthlyCharges', 'TotalCharges']].agg(['mean', 'max']).reset_index()
53
- means = stats.loc[0]
54
- maxs = stats.loc[1]
55
-
56
- # Metrics calculation function
57
- def calculate_metrics(y_true, y_pred):
58
- return {
59
- 'Accuracy': accuracy_score(y_true, y_pred) * 100,
60
- 'Recall': recall_score(y_true, y_pred) * 100,
61
- 'F1 Score': f1_score(y_true, y_pred) * 100,
62
- 'Precision': precision_score(y_true, y_pred) * 100,
63
- }
64
-
65
- # Prediction and metrics evaluation function
66
- def load_and_predict(
67
- gender, internet_service, contract, payment_method, tenure, monthly_charges, total_charges,
68
- senior_citizen, partner, dependents, phone_service, multiple_lines, online_security, online_backup,
69
- device_protection, tech_support, streaming_tv, streaming_movies, paperless_billing):
70
-
71
- # Ensure inputs are not None
72
- try:
73
- sample = {
74
- 'gender': int(gender == 'Male'),
75
- 'SeniorCitizen': int(senior_citizen == 'Yes'),
76
- 'Partner': int(partner == 'Yes'),
77
- 'Dependents': int(dependents == 'Yes'),
78
- 'tenure': int(tenure),
79
- 'PhoneService': int(phone_service == 'Yes'),
80
- 'MultipleLines': int(multiple_lines == 'Yes'),
81
- 'InternetService': str(internet_service),
82
- 'OnlineSecurity': int(online_security == 'Yes'),
83
- 'OnlineBackup': int(online_backup == 'Yes'),
84
- 'DeviceProtection': int(device_protection == 'Yes'),
85
- 'TechSupport': int(tech_support == 'Yes'),
86
- 'StreamingTV': int(streaming_tv == 'Yes'),
87
- 'StreamingMovies': int(streaming_movies == 'Yes'),
88
- 'Contract': str(contract),
89
- 'PaperlessBilling': int(paperless_billing == 'Yes'),
90
- 'PaymentMethod': str(payment_method),
91
- 'MonthlyCharges': float(monthly_charges),
92
- 'TotalCharges': float(total_charges)
93
- }
94
-
95
- sample_df = pd.DataFrame([sample])
96
- sample_trans = loaded_preprocessor.transform(sample_df)
97
- X_trans = loaded_preprocessor.transform(X)
98
-
99
- # Using SMOTE to handle class imbalance
100
- X_resampled, y_resampled = SMOTE(random_state=42).fit_resample(X_trans, y)
101
-
102
- results = []
103
- for name, model in models.items():
104
- churn_pred = model.predict(sample_trans)
105
- y_resampled_pred = model.predict(X_resampled)
106
- metrics = calculate_metrics(y_resampled, y_resampled_pred)
107
-
108
- results.append({
109
- 'Model': name,
110
- 'Predicted Churn': 'Yes' if churn_pred[0] == 1 else 'No',
111
- **metrics,
112
- })
113
-
114
- return pd.DataFrame(results).sort_values(by='Accuracy', ascending=False).reset_index(drop=True)
115
-
116
- except Exception as e:
117
- return f"An error occurred during model loading or prediction: {str(e)}"
118
-
119
- # Gradio Interface setup
120
- input_components = [
121
- gr.Radio(label="Gender", choices=input_choices['gender'], value=input_choices['gender'][0]),
122
- gr.Dropdown(label="Internet Service", choices=input_choices['internet_service'], value=input_choices['internet_service'][0]),
123
- gr.Dropdown(label="Contract", choices=input_choices['contract'], value=input_choices['contract'][0]),
124
- gr.Dropdown(label="Payment Method", choices=input_choices['payment_method'], value=input_choices['payment_method'][0]),
125
- gr.Slider(label="Tenure (Months)", minimum=0, maximum=int(maxs['tenure'] * 1.5), value=int(means['tenure'])),
126
- gr.Number(label="Monthly Charges", minimum=0.0, maximum=float(maxs['MonthlyCharges'] * 1.5), value=float(means['MonthlyCharges'])),
127
- gr.Number(label="Total Charges", minimum=0.0, maximum=float(maxs['TotalCharges'] * 1.5), value=float(means['TotalCharges'])),
128
- gr.Radio(label="Senior Citizen", choices=input_choices['others'], value=input_choices['others'][0]),
129
- gr.Radio(label="Partner", choices=input_choices['others'], value=input_choices['others'][0]),
130
- gr.Radio(label="Dependents", choices=input_choices['others'], value=input_choices['others'][0]),
131
- gr.Radio(label="Phone Service", choices=input_choices['others'], value=input_choices['others'][0]),
132
- gr.Radio(label="Multiple Lines", choices=input_choices['others'], value=input_choices['others'][0]),
133
- gr.Radio(label="Online Security", choices=input_choices['others'], value=input_choices['others'][0]),
134
- gr.Radio(label="Online Backup", choices=input_choices['others'], value=input_choices['others'][0]),
135
- gr.Radio(label="Device Protection", choices=input_choices['others'], value=input_choices['others'][0]),
136
- gr.Radio(label="Tech Support", choices=input_choices['others'], value=input_choices['others'][0]),
137
- gr.Radio(label="Streaming TV", choices=input_choices['others'], value=input_choices['others'][0]),
138
- gr.Radio(label="Streaming Movies", choices=input_choices['others'], value=input_choices['others'][0]),
139
- gr.Radio(label="Paperless Billing", choices=input_choices['others'], value=input_choices['others'][0]),
140
- ]
141
-
142
- output_component = gr.DataFrame()
143
-
144
- # Launching the Gradio Interface
145
- gr.Interface(
146
- fn=load_and_predict,
147
- inputs=input_components,
148
- outputs=output_component,
149
- title="♻️ Customer Churn Prediction",
150
- description="Enter the following information to predict customer churn.",
151
- flagging_mode="never" # Replacing allow_flagging with flagging_mode
152
  ).launch()
 
1
+ import os
2
+ import joblib
3
+ import pandas as pd
4
+ from imblearn.over_sampling import SMOTE
5
+ from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
6
+ import gradio as gr
7
+
8
+ # Load models and preprocessor
9
+ model_dir = 'models'
10
+ data_dir = 'datasets'
11
+
12
+ preprocessor_path = os.path.join(model_dir, 'churn_preprocessor.joblib')
13
+ loaded_preprocessor = joblib.load(preprocessor_path)
14
+
15
+ model_names = [
16
+ 'Ada Boost Classifier',
17
+ 'Gradient Boosting Classifier',
18
+ 'LGBM Classifier',
19
+ 'LogisticRegression',
20
+ 'XGBoost Classifier',
21
+ ]
22
+ model_paths = {name: os.path.join(model_dir, f"{name.replace(' ', '')}.joblib") for name in model_names}
23
+
24
+ # Load models safely
25
+ models = {}
26
+ for name, path in model_paths.items():
27
+ try:
28
+ models[name] = joblib.load(path)
29
+ except Exception as e:
30
+ print(f"Error loading model {name} from {path}: {str(e)}")
31
+
32
+ # Load dataset
33
+ data_path = os.path.join(data_dir, 'cleaned_IT_customer_churn.csv')
34
+ df = pd.read_csv(data_path)
35
+
36
+ # Prepare features and target
37
+ X = df.drop(columns=['Churn'])
38
+ y = df['Churn']
39
+
40
+ # Predefined input choices
41
+ input_choices = {
42
+ 'gender': ['Female', 'Male'],
43
+ 'internet_service': ['DSL', 'Fiber optic', 'No'],
44
+ 'contract': ['Month-to-month', 'One year', 'Two year'],
45
+ 'payment_method': ['Electronic check', 'Mailed check', 'Bank transfer (automatic)', 'Credit card (automatic)'],
46
+ 'others' : ['No', 'Yes']
47
+ }
48
+
49
+ # Pre-computed statistics for default values
50
+ stats = df[['tenure', 'MonthlyCharges', 'TotalCharges']].agg(['mean', 'max']).reset_index()
51
+ means = stats.loc[0]
52
+ maxs = stats.loc[1]
53
+
54
+ # Metrics calculation function
55
+ def calculate_metrics(y_true, y_pred):
56
+ return {
57
+ 'Accuracy': accuracy_score(y_true, y_pred) * 100,
58
+ 'Recall': recall_score(y_true, y_pred) * 100,
59
+ 'F1 Score': f1_score(y_true, y_pred) * 100,
60
+ 'Precision': precision_score(y_true, y_pred) * 100,
61
+ }
62
+
63
+ # Prediction and metrics evaluation function
64
+ def load_and_predict(
65
+ gender, internet_service, contract, payment_method, tenure, monthly_charges, total_charges,
66
+ senior_citizen, partner, dependents, phone_service, multiple_lines, online_security, online_backup,
67
+ device_protection, tech_support, streaming_tv, streaming_movies, paperless_billing):
68
+
69
+ # Ensure inputs are not None
70
+ try:
71
+ sample = {
72
+ 'gender': int(gender == 'Male'),
73
+ 'SeniorCitizen': int(senior_citizen == 'Yes'),
74
+ 'Partner': int(partner == 'Yes'),
75
+ 'Dependents': int(dependents == 'Yes'),
76
+ 'tenure': int(tenure),
77
+ 'PhoneService': int(phone_service == 'Yes'),
78
+ 'MultipleLines': int(multiple_lines == 'Yes'),
79
+ 'InternetService': str(internet_service),
80
+ 'OnlineSecurity': int(online_security == 'Yes'),
81
+ 'OnlineBackup': int(online_backup == 'Yes'),
82
+ 'DeviceProtection': int(device_protection == 'Yes'),
83
+ 'TechSupport': int(tech_support == 'Yes'),
84
+ 'StreamingTV': int(streaming_tv == 'Yes'),
85
+ 'StreamingMovies': int(streaming_movies == 'Yes'),
86
+ 'Contract': str(contract),
87
+ 'PaperlessBilling': int(paperless_billing == 'Yes'),
88
+ 'PaymentMethod': str(payment_method),
89
+ 'MonthlyCharges': float(monthly_charges),
90
+ 'TotalCharges': float(total_charges)
91
+ }
92
+
93
+ sample_df = pd.DataFrame([sample])
94
+ sample_trans = loaded_preprocessor.transform(sample_df)
95
+ X_trans = loaded_preprocessor.transform(X)
96
+
97
+ # Using SMOTE to handle class imbalance
98
+ X_resampled, y_resampled = SMOTE(random_state=42).fit_resample(X_trans, y)
99
+
100
+ results = []
101
+ for name, model in models.items():
102
+ churn_pred = model.predict(sample_trans)
103
+ y_resampled_pred = model.predict(X_resampled)
104
+ metrics = calculate_metrics(y_resampled, y_resampled_pred)
105
+
106
+ results.append({
107
+ 'Model': name,
108
+ 'Predicted Churn': 'Yes' if churn_pred[0] == 1 else 'No',
109
+ **metrics,
110
+ })
111
+
112
+ return pd.DataFrame(results).sort_values(by='Accuracy', ascending=False).reset_index(drop=True)
113
+
114
+ except Exception as e:
115
+ return f"An error occurred during model loading or prediction: {str(e)}"
116
+
117
+ # Gradio Interface setup
118
+ input_components = [
119
+ gr.Radio(label="Gender", choices=input_choices['gender'], value=input_choices['gender'][0]),
120
+ gr.Dropdown(label="Internet Service", choices=input_choices['internet_service'], value=input_choices['internet_service'][0]),
121
+ gr.Dropdown(label="Contract", choices=input_choices['contract'], value=input_choices['contract'][0]),
122
+ gr.Dropdown(label="Payment Method", choices=input_choices['payment_method'], value=input_choices['payment_method'][0]),
123
+ gr.Slider(label="Tenure (Months)", minimum=0, maximum=int(maxs['tenure'] * 1.5), value=int(means['tenure'])),
124
+ gr.Number(label="Monthly Charges", minimum=0.0, maximum=float(maxs['MonthlyCharges'] * 1.5), value=float(means['MonthlyCharges'])),
125
+ gr.Number(label="Total Charges", minimum=0.0, maximum=float(maxs['TotalCharges'] * 1.5), value=float(means['TotalCharges'])),
126
+ gr.Radio(label="Senior Citizen", choices=input_choices['others'], value=input_choices['others'][0]),
127
+ gr.Radio(label="Partner", choices=input_choices['others'], value=input_choices['others'][0]),
128
+ gr.Radio(label="Dependents", choices=input_choices['others'], value=input_choices['others'][0]),
129
+ gr.Radio(label="Phone Service", choices=input_choices['others'], value=input_choices['others'][0]),
130
+ gr.Radio(label="Multiple Lines", choices=input_choices['others'], value=input_choices['others'][0]),
131
+ gr.Radio(label="Online Security", choices=input_choices['others'], value=input_choices['others'][0]),
132
+ gr.Radio(label="Online Backup", choices=input_choices['others'], value=input_choices['others'][0]),
133
+ gr.Radio(label="Device Protection", choices=input_choices['others'], value=input_choices['others'][0]),
134
+ gr.Radio(label="Tech Support", choices=input_choices['others'], value=input_choices['others'][0]),
135
+ gr.Radio(label="Streaming TV", choices=input_choices['others'], value=input_choices['others'][0]),
136
+ gr.Radio(label="Streaming Movies", choices=input_choices['others'], value=input_choices['others'][0]),
137
+ gr.Radio(label="Paperless Billing", choices=input_choices['others'], value=input_choices['others'][0]),
138
+ ]
139
+
140
+ output_component = gr.DataFrame()
141
+
142
+ # Launching the Gradio Interface
143
+ gr.Interface(
144
+ fn=load_and_predict,
145
+ inputs=input_components,
146
+ outputs=output_component,
147
+ title="♻️ Customer Churn Prediction",
148
+ description="Enter the following information to predict customer churn.",
149
+ flagging_mode="never" # Replacing allow_flagging with flagging_mode
 
 
150
  ).launch()