Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
import matplotlib.font_manager
|
4 |
+
from sklearn import svm
|
5 |
+
|
6 |
+
xx, yy = np.meshgrid(np.linspace(-5, 5, 500), np.linspace(-5, 5, 500))
|
7 |
+
# Generate train data
|
8 |
+
X = 0.3 * np.random.randn(100, 2)
|
9 |
+
X_train = np.r_[X + 2, X - 2]
|
10 |
+
# Generate some regular novel observations
|
11 |
+
X = 0.3 * np.random.randn(20, 2)
|
12 |
+
X_test = np.r_[X + 2, X - 2]
|
13 |
+
# Generate some abnormal novel observations
|
14 |
+
X_outliers = np.random.uniform(low=-4, high=4, size=(20, 2))
|
15 |
+
|
16 |
+
def createPlotAndPlotPoint(x_new=9, y_new=9):
|
17 |
+
clf = svm.OneClassSVM(nu=0.1, kernel="rbf", gamma=0.1)
|
18 |
+
clf.fit(X_train)
|
19 |
+
y_pred_train = clf.predict(X_train)
|
20 |
+
y_pred_test = clf.predict(X_test)
|
21 |
+
y_pred_outliers = clf.predict(X_outliers)
|
22 |
+
n_error_train = y_pred_train[y_pred_train == -1].size
|
23 |
+
n_error_test = y_pred_test[y_pred_test == -1].size
|
24 |
+
n_error_outliers = y_pred_outliers[y_pred_outliers == 1].size
|
25 |
+
|
26 |
+
# plot the line, the points, and the nearest vectors to the plane
|
27 |
+
Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()])
|
28 |
+
Z = Z.reshape(xx.shape)
|
29 |
+
|
30 |
+
plt.title("Novelty Detection")
|
31 |
+
plt.contourf(xx, yy, Z, levels=np.linspace(Z.min(), 0, 7), cmap=plt.cm.PuBu)
|
32 |
+
a = plt.contour(xx, yy, Z, levels=[0], linewidths=3, colors="darkred")
|
33 |
+
plt.contourf(xx, yy, Z, levels=[0, Z.max()], colors="palevioletred")
|
34 |
+
|
35 |
+
s = 40
|
36 |
+
b1 = plt.scatter(X_train[:, 0], X_train[:, 1], c="white", s=s, edgecolors="k")
|
37 |
+
b2 = plt.scatter(X_test[:, 0], X_test[:, 1], c="blueviolet", s=s, edgecolors="k")
|
38 |
+
c = plt.scatter(X_outliers[:, 0], X_outliers[:, 1], c="gold", s=s, edgecolors="k")
|
39 |
+
plt.axis("tight")
|
40 |
+
plt.xlim((-5, 5))
|
41 |
+
plt.ylim((-5, 5))
|
42 |
+
plt.legend(
|
43 |
+
[a.collections[0], b1, b2, c],
|
44 |
+
[
|
45 |
+
"learned frontier",
|
46 |
+
"training observations",
|
47 |
+
"new regular observations",
|
48 |
+
"new abnormal observations",
|
49 |
+
],
|
50 |
+
loc="upper left",
|
51 |
+
prop=matplotlib.font_manager.FontProperties(size=11),
|
52 |
+
)
|
53 |
+
|
54 |
+
isAbnormal = (clf.predict([[x_new,y_new]])[0] == -1)
|
55 |
+
markerfacecolor = "gold" if isAbnormal else "blueviolet"
|
56 |
+
outputText = "abnormal" if isAbnormal else "regular"
|
57 |
+
plt.plot(x_new, y_new, marker="o", markersize=15, markeredgecolor="m", markerfacecolor=markerfacecolor)
|
58 |
+
|
59 |
+
plt.xlabel(
|
60 |
+
"error train: %d/200 ; errors novel regular: %d/40 ; errors novel abnormal: %d/40"
|
61 |
+
% (n_error_train, n_error_test, n_error_outliers)
|
62 |
+
)
|
63 |
+
return plt, outputText.capitalize()
|
64 |
+
|
65 |
+
with gr.Blocks() as demo:
|
66 |
+
link = "https://scikit-learn.org/stable/auto_examples/svm/plot_oneclass.html#sphx-glr-auto-examples-svm-plot-oneclass-py"
|
67 |
+
gr.Markdown("# Novelty detection using One-class SVM")
|
68 |
+
gr.Markdown(f"## This demo is based on this [scikit-learn example]({link}).")
|
69 |
+
gr.Markdown("In this demo, we use One-class SVM (Support Vector Machine) to learn the decision function for novelty detection.")
|
70 |
+
gr.Markdown("Furthermore, we **test** the algorithm on new data that would be classified as similar or different to the training set.")
|
71 |
+
|
72 |
+
gr.Markdown("#### You can define the coordinates of the new data point below!")
|
73 |
+
|
74 |
+
x_new = gr.Slider(-5,5,0, label="X", info="Choose the X coordinate")
|
75 |
+
y_new = gr.Slider(-5,5,0, label="Y", info="Choose the Y coordinate")
|
76 |
+
|
77 |
+
with gr.Row():
|
78 |
+
with gr.Column(scale=2):
|
79 |
+
plot = gr.Plot(label=f"Decision function plot")
|
80 |
+
with gr.Column(scale=1):
|
81 |
+
prediction = gr.Textbox(label="Is the new data point regular or abormal?")
|
82 |
+
|
83 |
+
x_new.change(createPlotAndPlotPoint, inputs=[x_new, y_new], outputs=[plot, prediction])
|
84 |
+
y_new.change(createPlotAndPlotPoint, inputs=[x_new, y_new], outputs=[plot, prediction])
|
85 |
+
demo.load(createPlotAndPlotPoint, inputs=[x_new, y_new], outputs=[plot, prediction])
|
86 |
+
|
87 |
+
if __name__ == "__main__":
|
88 |
+
demo.launch()
|