Update mymodel.py
Browse files- mymodel.py +4 -4
mymodel.py
CHANGED
@@ -12,6 +12,8 @@ from sklearn.tree import DecisionTreeClassifier
|
|
12 |
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
|
13 |
import seaborn as sn
|
14 |
import matplotlib.pyplot as plt
|
|
|
|
|
15 |
|
16 |
|
17 |
class DiseasePrediction:
|
@@ -100,11 +102,9 @@ class DiseasePrediction:
|
|
100 |
if self.model_name == 'mnb':
|
101 |
self.clf = MultinomialNB()
|
102 |
elif self.model_name == 'decision_tree':
|
103 |
-
self.clf =
|
104 |
-
elif self.model_name == 'random_forest':
|
105 |
-
self.clf = RandomForestClassifier(n_estimators=self.config['model']['random_forest']['n_estimators'])
|
106 |
elif self.model_name == 'gradient_boost':
|
107 |
-
self.clf =
|
108 |
criterion=self.config['model']['gradient_boost']['criterion'])
|
109 |
return self.clf
|
110 |
|
|
|
12 |
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
|
13 |
import seaborn as sn
|
14 |
import matplotlib.pyplot as plt
|
15 |
+
from concrete.ml.sklearn import DecisionTreeClassifier as ConcreteDecisionTreeClassifier
|
16 |
+
from concrete.ml.sklearn import XGBClassifier as ConcreteXGBClassifier
|
17 |
|
18 |
|
19 |
class DiseasePrediction:
|
|
|
102 |
if self.model_name == 'mnb':
|
103 |
self.clf = MultinomialNB()
|
104 |
elif self.model_name == 'decision_tree':
|
105 |
+
self.clf = ConcreteDecisionTreeClassifier(criterion=self.config['model']['decision_tree']['criterion'])
|
|
|
|
|
106 |
elif self.model_name == 'gradient_boost':
|
107 |
+
self.clf = ConcreteXGBClassifier(n_estimators=self.config['model']['gradient_boost']['n_estimators'],
|
108 |
criterion=self.config['model']['gradient_boost']['criterion'])
|
109 |
return self.clf
|
110 |
|