Nos7 commited on
Commit
aeb2cab
·
verified ·
1 Parent(s): 25ec159

Update mymodel.py

Browse files
Files changed (1) hide show
  1. mymodel.py +4 -4
mymodel.py CHANGED
@@ -12,6 +12,8 @@ from sklearn.tree import DecisionTreeClassifier
12
  from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
13
  import seaborn as sn
14
  import matplotlib.pyplot as plt
 
 
15
 
16
 
17
  class DiseasePrediction:
@@ -100,11 +102,9 @@ class DiseasePrediction:
100
  if self.model_name == 'mnb':
101
  self.clf = MultinomialNB()
102
  elif self.model_name == 'decision_tree':
103
- self.clf = DecisionTreeClassifier(criterion=self.config['model']['decision_tree']['criterion'])
104
- elif self.model_name == 'random_forest':
105
- self.clf = RandomForestClassifier(n_estimators=self.config['model']['random_forest']['n_estimators'])
106
  elif self.model_name == 'gradient_boost':
107
- self.clf = GradientBoostingClassifier(n_estimators=self.config['model']['gradient_boost']['n_estimators'],
108
  criterion=self.config['model']['gradient_boost']['criterion'])
109
  return self.clf
110
 
 
12
  from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
13
  import seaborn as sn
14
  import matplotlib.pyplot as plt
15
+ from concrete.ml.sklearn import DecisionTreeClassifier as ConcreteDecisionTreeClassifier
16
+ from concrete.ml.sklearn import XGBClassifier as ConcreteXGBClassifier
17
 
18
 
19
  class DiseasePrediction:
 
102
  if self.model_name == 'mnb':
103
  self.clf = MultinomialNB()
104
  elif self.model_name == 'decision_tree':
105
+ self.clf = ConcreteDecisionTreeClassifier(criterion=self.config['model']['decision_tree']['criterion'])
 
 
106
  elif self.model_name == 'gradient_boost':
107
+ self.clf = ConcreteXGBClassifier(n_estimators=self.config['model']['gradient_boost']['n_estimators'],
108
  criterion=self.config['model']['gradient_boost']['criterion'])
109
  return self.clf
110