Nos7 committed on
Commit 25ec159 · verified · 1 Parent(s): 293fd2e

Create mymodel.py


Adapt the code for the FHE library

Files changed (1)
  1. mymodel.py +170 -0
mymodel.py ADDED
@@ -0,0 +1,170 @@
# Import Dependencies
import yaml
from joblib import dump, load
import pandas as pd
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
# Naive Bayes Approach
from sklearn.naive_bayes import MultinomialNB
# Trees Approach
from sklearn.tree import DecisionTreeClassifier
# Ensemble Approach
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
import seaborn as sn
import matplotlib.pyplot as plt


class DiseasePrediction:
    # Initialize and Load the Config File
    def __init__(self, model_name=None):
        # Load Config File
        try:
            with open('./config.yaml', 'r') as f:
                self.config = yaml.safe_load(f)
        except Exception as e:
            # Re-raise: without a valid config the attributes below cannot be set
            print("Error reading Config file...")
            raise

        # Verbose
        self.verbose = self.config['verbose']
        # Load Training Data
        self.train_features, self.train_labels, self.train_df = self._load_train_dataset()
        # Load Test Data
        self.test_features, self.test_labels, self.test_df = self._load_test_dataset()
        # Feature Correlation in Training Data
        self._feature_correlation(data_frame=self.train_df, show_fig=False)
        # Model Definition
        self.model_name = model_name
        # Model Save Path
        self.model_save_path = self.config['model_save_path']

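    # Example of the dictionary yaml.safe_load is expected to return from
    # config.yaml. The keys are inferred from the lookups in this class; the
    # values shown are illustrative assumptions, not part of this commit:
    #
    #   {'verbose': True,
    #    'random_state': 42,
    #    'model_save_path': './saved_model/',
    #    'dataset': {'training_data_path': './data/training.csv',
    #                'test_data_path': './data/testing.csv',
    #                'validation_size': 0.3},
    #    'model': {'decision_tree': {'criterion': 'gini'},
    #              'random_forest': {'n_estimators': 100},
    #              'gradient_boost': {'n_estimators': 100, 'criterion': 'friedman_mse'}}}
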
    # Function to Load Train Dataset
    def _load_train_dataset(self):
        df_train = pd.read_csv(self.config['dataset']['training_data_path'])
        cols = df_train.columns
        # Drop the label column ('prognosis') and the unnamed trailing column
        # present in the training CSV, keeping only the symptom features
        cols = cols[:-2]
        train_features = df_train[cols]
        train_labels = df_train['prognosis']

        # Check for data sanity
        assert (len(train_features.iloc[0]) == 132)
        assert (len(train_labels) == train_features.shape[0])

        if self.verbose:
            print("Training Data Shape: ", df_train.shape)
            print("Training Features: ", train_features.shape)
            print("Training Labels: ", train_labels.shape)
        return train_features, train_labels, df_train

    # Function to Load Test Dataset
    def _load_test_dataset(self):
        df_test = pd.read_csv(self.config['dataset']['test_data_path'])
        cols = df_test.columns
        # Drop the label column ('prognosis'), keeping only the symptom features
        cols = cols[:-1]
        test_features = df_test[cols]
        test_labels = df_test['prognosis']

        # Check for data sanity
        assert (len(test_features.iloc[0]) == 132)
        assert (len(test_labels) == test_features.shape[0])

        if self.verbose:
            print("Test Data Shape: ", df_test.shape)
            print("Test Features: ", test_features.shape)
            print("Test Labels: ", test_labels.shape)
        return test_features, test_labels, df_test

    # Features Correlation
    def _feature_correlation(self, data_frame=None, show_fig=False):
        # Get feature correlation over the numeric symptom columns
        # ('prognosis' is categorical and is excluded)
        corr = data_frame.corr(numeric_only=True)
        sn.heatmap(corr, square=True, annot=False, cmap="YlGnBu")
        plt.title("Feature Correlation")
        plt.tight_layout()
        # Save before showing, otherwise the saved figure may be empty
        plt.savefig('feature_correlation.png')
        if show_fig:
            plt.show()

    # Dataset Train Validation Split
    def _train_val_split(self):
        X_train, X_val, y_train, y_val = train_test_split(self.train_features, self.train_labels,
                                                          test_size=self.config['dataset']['validation_size'],
                                                          random_state=self.config['random_state'])

        if self.verbose:
            print("Number of Training Features: {0}\tNumber of Training Labels: {1}".format(len(X_train), len(y_train)))
            print("Number of Validation Features: {0}\tNumber of Validation Labels: {1}".format(len(X_val), len(y_val)))
        return X_train, y_train, X_val, y_val

    # Model Selection
    def select_model(self):
        if self.model_name == 'mnb':
            self.clf = MultinomialNB()
        elif self.model_name == 'decision_tree':
            self.clf = DecisionTreeClassifier(criterion=self.config['model']['decision_tree']['criterion'])
        elif self.model_name == 'random_forest':
            self.clf = RandomForestClassifier(n_estimators=self.config['model']['random_forest']['n_estimators'])
        elif self.model_name == 'gradient_boost':
            self.clf = GradientBoostingClassifier(n_estimators=self.config['model']['gradient_boost']['n_estimators'],
                                                  criterion=self.config['model']['gradient_boost']['criterion'])
        else:
            # Fail early instead of leaving self.clf undefined
            raise ValueError("Unknown model name: {}".format(self.model_name))
        return self.clf

    # ML Model
    def train_model(self):
        # Get the Data
        X_train, y_train, X_val, y_val = self._train_val_split()
        classifier = self.select_model()
        # Training the Model
        classifier = classifier.fit(X_train, y_train)
        # Trained Model Evaluation on Validation Dataset
        confidence = classifier.score(X_val, y_val)
        # Validation Data Prediction
        y_pred = classifier.predict(X_val)
        # Model Validation Accuracy
        accuracy = accuracy_score(y_val, y_pred)
        # Model Confusion Matrix
        conf_mat = confusion_matrix(y_val, y_pred)
        # Model Classification Report
        clf_report = classification_report(y_val, y_pred)
        # Model Cross Validation Score
        score = cross_val_score(classifier, X_val, y_val, cv=3)

        if self.verbose:
            print('\nValidation Score: ', confidence)
            print('\nValidation Prediction: ', y_pred)
            print('\nValidation Accuracy: ', accuracy)
            print('\nValidation Confusion Matrix: \n', conf_mat)
            print('\nCross Validation Score: \n', score)
            print('\nClassification Report: \n', clf_report)

        # Save Trained Model
        dump(classifier, self.model_save_path + self.model_name + ".joblib")

    # Function to Make Predictions on Test Data
    def make_prediction(self, saved_model_name=None, test_data=None):
        try:
            # Load Trained Model
            clf = load(self.model_save_path + saved_model_name + ".joblib")
        except Exception as e:
            # Re-raise: there is no model to predict with
            print("Model not found...")
            raise

        if test_data is not None:
            result = clf.predict(test_data)
            return result
        else:
            result = clf.predict(self.test_features)
            accuracy = accuracy_score(self.test_labels, result)
            clf_report = classification_report(self.test_labels, result)
            return accuracy, clf_report


if __name__ == "__main__":
    # Model Currently Training
    current_model_name = 'decision_tree'
    # Instantiate the Class
    dp = DiseasePrediction(model_name=current_model_name)
    # Train the Model
    dp.train_model()
    # Get Model Performance on Test Data
    # (use a name that does not shadow sklearn's classification_report import)
    test_accuracy, test_classification_report = dp.make_prediction(saved_model_name=current_model_name)
    print("Model Test Accuracy: ", test_accuracy)
    print("Test Data Classification Report: \n", test_classification_report)