Spaces:
Runtime error
Runtime error
File size: 2,332 Bytes
9c323ee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
"""
The implementation of the decision tree model for anomaly detection.
Authors:
LogPAI Team
Reference:
[1] Mike Chen, Alice X. Zheng, Jim Lloyd, Michael I. Jordan, Eric Brewer.
Failure Diagnosis Using Decision Trees. IEEE International Conference
on Autonomic Computing (ICAC), 2004.
"""
import numpy as np
from sklearn import tree
from ..utils import metrics
class DecisionTree(object):
    """Decision tree model for anomaly detection.

    Thin wrapper around ``sklearn.tree.DecisionTreeClassifier`` that adds
    console summaries and an ``evaluate`` helper.

    Attributes
    ----------
        classifier: object, the underlying DecisionTreeClassifier instance
    """

    def __init__(self, criterion='gini', max_depth=None, max_features=None, class_weight=None):
        """ Build the underlying decision tree classifier

        Arguments
        ---------
            criterion, max_depth, max_features, class_weight: forwarded
                unchanged to DecisionTreeClassifier. See the API reference:
                https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html
        """
        self.classifier = tree.DecisionTreeClassifier(
            criterion=criterion,
            max_depth=max_depth,
            max_features=max_features,
            class_weight=class_weight,
        )

    def fit(self, X, y):
        """ Train the decision tree on labeled data

        Arguments
        ---------
            X: ndarray, the event count matrix of shape num_instances-by-num_events
            y: the label vector handed to the classifier together with X
               (presumably one label per row of X -- confirm against caller)

        Returns
        -------
            self: the fitted model, so calls can be chained
        """
        print('====== Model summary ======')
        self.classifier.fit(X, y)
        # Returning self follows the scikit-learn estimator convention;
        # callers that ignore the return value are unaffected.
        return self

    def predict(self, X):
        """ Predict anomaly labels with the trained decision tree

        Arguments
        ---------
            X: the input event count matrix

        Returns
        -------
            y_pred: ndarray, the predicted label vector of shape (num_instances,)
        """
        return self.classifier.predict(X)

    def predict_proba(self, X):
        """ Predict class probabilities with the trained decision tree

        Arguments
        ---------
            X: the input event count matrix

        Returns
        -------
            y_proba: the output of the classifier's predict_proba; for
                scikit-learn trees this is an ndarray of per-class
                probabilities of shape (num_instances, num_classes)
        """
        return self.classifier.predict_proba(X)

    def evaluate(self, X, y_true):
        """ Predict on X, print and return the scores against y_true

        Arguments
        ---------
            X: the input event count matrix
            y_true: the ground-truth label vector compared with the predictions

        Returns
        -------
            precision, recall, f1: the three values produced by the
                project's ``metrics`` helper for (y_pred, y_true)
        """
        print('====== Evaluation summary ======')
        y_pred = self.predict(X)
        precision, recall, f1 = metrics(y_pred, y_true)
        print('Precision: {:.3f}, recall: {:.3f}, F1-measure: {:.3f}\n'.format(precision, recall, f1))
        return precision, recall, f1
|