|
import functools

import gradio as gr
import matplotlib.font_manager
import matplotlib.pyplot as plt
import numpy as np
from sklearn import svm
|
|
|
# Dense 500x500 evaluation grid over the plotted window [-5, 5] x [-5, 5];
# the decision function is evaluated on it to draw the contours.
xx, yy = np.meshgrid(
    np.linspace(start=-5, stop=5, num=500),
    np.linspace(start=-5, stop=5, num=500),
)

# Training observations: 200 points in two clusters of standard deviation 0.3,
# centred at (2, 2) and (-2, -2). The same base noise is shifted both ways.
# NOTE(review): no random seed is set, so the data (and plot) differ per run.
X = np.random.randn(100, 2) * 0.3
X_train = np.vstack((X + 2, X - 2))

# New regular observations: 40 points drawn from the same two-cluster layout.
X = np.random.randn(20, 2) * 0.3
X_test = np.vstack((X + 2, X - 2))

# New abnormal observations: 20 points uniform over [-4, 4) in each coordinate.
X_outliers = np.random.uniform(low=-4, high=4, size=(20, 2))
|
|
|
@functools.lru_cache(maxsize=1)
def _fit_novelty_model():
    """Fit the One-class SVM once and precompute everything independent of the query point.

    The model, its decision-function surface on the ``xx``/``yy`` grid and the
    per-dataset error counts depend only on module-level data. They are computed
    on the first call and reused by every later slider event, instead of refitting
    and re-evaluating all 500x500 grid points each time.

    NOTE(review): the cache assumes ``X_train``, ``X_test``, ``X_outliers``, ``xx``
    and ``yy`` are never rebound after import.

    Returns:
        tuple: ``(clf, Z, n_error_train, n_error_test, n_error_outliers)``.
        ``clf`` is the fitted ``svm.OneClassSVM``. ``Z`` is its decision function
        evaluated on the grid and reshaped to ``xx.shape``. The counts are the
        number of misclassified points in each dataset.
    """
    clf = svm.OneClassSVM(nu=0.1, kernel="rbf", gamma=0.1)
    clf.fit(X_train)

    # predict() labels outliers -1 and inliers +1.
    n_error_train = int(np.sum(clf.predict(X_train) == -1))
    n_error_test = int(np.sum(clf.predict(X_test) == -1))
    # For the abnormal set, an "error" is a point accepted as regular.
    n_error_outliers = int(np.sum(clf.predict(X_outliers) == 1))

    Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)
    return clf, Z, n_error_train, n_error_test, n_error_outliers


def createPlotAndPlotPoint(x_new=9, y_new=9):
    """Draw the One-class SVM novelty-detection plot and classify one query point.

    Args:
        x_new: X coordinate of the query point. The default (9) lies outside the
            plotted [-5, 5] window, so the marker is clipped from view.
        y_new: Y coordinate of the query point (default 9, likewise off-plot).

    Returns:
        tuple: ``(fig, verdict)``.
        ``fig`` is a new ``matplotlib.figure.Figure`` showing the learned
        frontier, the training/regular/abnormal observations and the query point.
        ``verdict`` is ``"Abnormal"`` if the model predicts the point as an
        outlier, else ``"Regular"``.
    """
    # Local imports keep these names scoped to the function that needs them.
    from matplotlib.figure import Figure
    from matplotlib.lines import Line2D

    clf, Z, n_error_train, n_error_test, n_error_outliers = _fit_novelty_model()

    # Build the figure without pyplot. Each call gets an independent figure that is
    # not registered in pyplot's global figure manager. This avoids leaking a figure
    # per Gradio event and avoids sharing "current figure" state between requests.
    fig = Figure()
    ax = fig.subplots()
    ax.set_title("Novelty Detection")

    # Graded blues for decision values below 0, the frontier at 0, and a flat fill
    # for values from 0 up to the maximum.
    ax.contourf(xx, yy, Z, levels=np.linspace(Z.min(), 0, 7), cmap=plt.cm.PuBu)
    ax.contour(xx, yy, Z, levels=[0], linewidths=3, colors="darkred")
    ax.contourf(xx, yy, Z, levels=[0, Z.max()], colors="palevioletred")

    s = 40
    b1 = ax.scatter(X_train[:, 0], X_train[:, 1], c="white", s=s, edgecolors="k")
    b2 = ax.scatter(X_test[:, 0], X_test[:, 1], c="blueviolet", s=s, edgecolors="k")
    c = ax.scatter(X_outliers[:, 0], X_outliers[:, 1], c="gold", s=s, edgecolors="k")
    ax.set_xlim((-5, 5))
    ax.set_ylim((-5, 5))

    # ContourSet.collections was removed in Matplotlib 3.10, so the frontier's
    # legend entry uses a proxy line styled like the contour drawn above.
    frontier_handle = Line2D([], [], color="darkred", linewidth=3)
    ax.legend(
        [frontier_handle, b1, b2, c],
        [
            "learned frontier",
            "training observations",
            "new regular observations",
            "new abnormal observations",
        ],
        loc="upper left",
        prop=matplotlib.font_manager.FontProperties(size=11),
    )

    is_abnormal = clf.predict([[x_new, y_new]])[0] == -1
    ax.plot(
        x_new,
        y_new,
        marker="o",
        markersize=15,
        markeredgecolor="m",
        markerfacecolor="gold" if is_abnormal else "blueviolet",
    )

    # Denominators come from the data so they stay correct if the sample sizes
    # change. The abnormal set has 20 points, not 40.
    ax.set_xlabel(
        "error train: %d/%d ; errors novel regular: %d/%d ; errors novel abnormal: %d/%d"
        % (
            n_error_train,
            len(X_train),
            n_error_test,
            len(X_test),
            n_error_outliers,
            len(X_outliers),
        )
    )

    return fig, ("Abnormal" if is_abnormal else "Regular")
|
|
|
# Gradio UI: two sliders choose a query point, and each change re-renders the
# decision-function plot and the regular/abnormal verdict for that point.
with gr.Blocks() as demo:
    link = "https://scikit-learn.org/stable/auto_examples/svm/plot_oneclass.html#sphx-glr-auto-examples-svm-plot-oneclass-py"
    gr.Markdown("# Novelty detection using One-class SVM")
    gr.Markdown(f"This demo is based on this [scikit-learn example]({link}).")
    gr.Markdown("In this demo, we use One-class SVM (Support Vector Machine) to learn the decision function for novelty detection.")
    gr.Markdown("Furthermore, we **test** the algorithm on new data that would be classified as similar or different to the training set.")

    gr.Markdown("#### You can define the coordinates of the new data point below!")

    # Slider range matches the plotted window [-5, 5]; both start at the origin.
    x_new = gr.Slider(-5, 5, 0, label="X", info="Choose the X coordinate")
    y_new = gr.Slider(-5, 5, 0, label="Y", info="Choose the Y coordinate")

    with gr.Row():
        with gr.Column(scale=2):
            plot = gr.Plot(label="Decision function plot")
        with gr.Column(scale=1):
            prediction = gr.Textbox(label="Is the new data point regular or abnormal?")

    # Re-render when either slider moves, and once on page load so the plot is
    # never empty.
    point_inputs = [x_new, y_new]
    plot_outputs = [plot, prediction]
    x_new.change(createPlotAndPlotPoint, inputs=point_inputs, outputs=plot_outputs)
    y_new.change(createPlotAndPlotPoint, inputs=point_inputs, outputs=plot_outputs)
    demo.load(createPlotAndPlotPoint, inputs=point_inputs, outputs=plot_outputs)


if __name__ == "__main__":
    demo.launch()