MuskanMjn commited on
Commit
886a925
·
1 Parent(s): a3b8132

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -0
app.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ import matplotlib.font_manager
4
+ from sklearn import svm
5
+
6
+ xx, yy = np.meshgrid(np.linspace(-5, 5, 500), np.linspace(-5, 5, 500))
7
+ # Generate train data
8
+ X = 0.3 * np.random.randn(100, 2)
9
+ X_train = np.r_[X + 2, X - 2]
10
+ # Generate some regular novel observations
11
+ X = 0.3 * np.random.randn(20, 2)
12
+ X_test = np.r_[X + 2, X - 2]
13
+ # Generate some abnormal novel observations
14
+ X_outliers = np.random.uniform(low=-4, high=4, size=(20, 2))
15
+
16
+ def createPlotAndPlotPoint(x_new=9, y_new=9):
17
+ clf = svm.OneClassSVM(nu=0.1, kernel="rbf", gamma=0.1)
18
+ clf.fit(X_train)
19
+ y_pred_train = clf.predict(X_train)
20
+ y_pred_test = clf.predict(X_test)
21
+ y_pred_outliers = clf.predict(X_outliers)
22
+ n_error_train = y_pred_train[y_pred_train == -1].size
23
+ n_error_test = y_pred_test[y_pred_test == -1].size
24
+ n_error_outliers = y_pred_outliers[y_pred_outliers == 1].size
25
+
26
+ # plot the line, the points, and the nearest vectors to the plane
27
+ Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()])
28
+ Z = Z.reshape(xx.shape)
29
+
30
+ plt.title("Novelty Detection")
31
+ plt.contourf(xx, yy, Z, levels=np.linspace(Z.min(), 0, 7), cmap=plt.cm.PuBu)
32
+ a = plt.contour(xx, yy, Z, levels=[0], linewidths=3, colors="darkred")
33
+ plt.contourf(xx, yy, Z, levels=[0, Z.max()], colors="palevioletred")
34
+
35
+ s = 40
36
+ b1 = plt.scatter(X_train[:, 0], X_train[:, 1], c="white", s=s, edgecolors="k")
37
+ b2 = plt.scatter(X_test[:, 0], X_test[:, 1], c="blueviolet", s=s, edgecolors="k")
38
+ c = plt.scatter(X_outliers[:, 0], X_outliers[:, 1], c="gold", s=s, edgecolors="k")
39
+ plt.axis("tight")
40
+ plt.xlim((-5, 5))
41
+ plt.ylim((-5, 5))
42
+ plt.legend(
43
+ [a.collections[0], b1, b2, c],
44
+ [
45
+ "learned frontier",
46
+ "training observations",
47
+ "new regular observations",
48
+ "new abnormal observations",
49
+ ],
50
+ loc="upper left",
51
+ prop=matplotlib.font_manager.FontProperties(size=11),
52
+ )
53
+
54
+ isAbnormal = (clf.predict([[x_new,y_new]])[0] == -1)
55
+ markerfacecolor = "gold" if isAbnormal else "blueviolet"
56
+ outputText = "abnormal" if isAbnormal else "regular"
57
+ plt.plot(x_new, y_new, marker="o", markersize=15, markeredgecolor="m", markerfacecolor=markerfacecolor)
58
+
59
+ plt.xlabel(
60
+ "error train: %d/200 ; errors novel regular: %d/40 ; errors novel abnormal: %d/40"
61
+ % (n_error_train, n_error_test, n_error_outliers)
62
+ )
63
+ return plt, outputText.capitalize()
64
+
65
+ with gr.Blocks() as demo:
66
+ link = "https://scikit-learn.org/stable/auto_examples/svm/plot_oneclass.html#sphx-glr-auto-examples-svm-plot-oneclass-py"
67
+ gr.Markdown("# Novelty detection using One-class SVM")
68
+ gr.Markdown(f"## This demo is based on this [scikit-learn example]({link}).")
69
+ gr.Markdown("In this demo, we use One-class SVM (Support Vector Machine) to learn the decision function for novelty detection.")
70
+ gr.Markdown("Furthermore, we **test** the algorithm on new data that would be classified as similar or different to the training set.")
71
+
72
+ gr.Markdown("#### You can define the coordinates of the new data point below!")
73
+
74
+ x_new = gr.Slider(-5,5,0, label="X", info="Choose the X coordinate")
75
+ y_new = gr.Slider(-5,5,0, label="Y", info="Choose the Y coordinate")
76
+
77
+ with gr.Row():
78
+ with gr.Column(scale=2):
79
+ plot = gr.Plot(label=f"Decision function plot")
80
+ with gr.Column(scale=1):
81
+ prediction = gr.Textbox(label="Is the new data point regular or abormal?")
82
+
83
+ x_new.change(createPlotAndPlotPoint, inputs=[x_new, y_new], outputs=[plot, prediction])
84
+ y_new.change(createPlotAndPlotPoint, inputs=[x_new, y_new], outputs=[plot, prediction])
85
+ demo.load(createPlotAndPlotPoint, inputs=[x_new, y_new], outputs=[plot, prediction])
86
+
87
+ if __name__ == "__main__":
88
+ demo.launch()