marik0 commited on
Commit
f7996e9
·
1 Parent(s): 31fa95f

First attempt

Browse files
Files changed (1) hide show
  1. app.py +88 -4
app.py CHANGED
@@ -1,8 +1,92 @@
1
  import gradio as gr
2
 
3
- def greet(name):
4
- return "Hello " + name + "!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
 
8
- demo.launch()
 
1
  import gradio as gr
2
 
3
+ from sklearn.datasets import make_classification
4
+ from sklearn.model_selection import train_test_split
5
+ from sklearn.ensemble import RandomForestClassifier
6
+ from sklearn.inspection import permutation_importance
7
+
8
+ import numpy as np
9
+ import pandas as pd
10
+ import matplotlib.pyplot as plt
11
+
12
+ def create_dataset():
13
+ X, y = make_classification(
14
+ n_samples=1000,
15
+ n_features=10,
16
+ n_informative=3,
17
+ n_redundant=0,
18
+ n_repeated=0,
19
+ n_classes=2,
20
+ random_state=0,
21
+ shuffle=False,
22
+ )
23
+
24
+ X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=42)
25
+ return X_train, X_test, y_train, y_test
26
+
27
+ def train_model():
28
+
29
+ X_train, X_test, y_train, y_test = create_dataset()
30
+
31
+ feature_names = [f"feature {i}" for i in range(X_train.shape[1])]
32
+ forest = RandomForestClassifier(random_state=0)
33
+ forest.fit(X_train, y_train)
34
+
35
+ return forest, feature_names, X_test, y_test
36
+
37
+
38
+ def plot_mean_decrease(clf, feature_names):
39
+ importances = clf.feature_importances_
40
+ std = np.std([tree.feature_importances_ for tree in clf.estimators_], axis=0)
41
+
42
+ forest_importances = pd.Series(importances, index=feature_names)
43
+
44
+ fig, ax = plt.subplots()
45
+ forest_importances.plot.bar(yerr=std, ax=ax)
46
+ ax.set_title("Feature importances using MDI")
47
+ ax.set_ylabel("Mean decrease in impurity")
48
+ fig.tight_layout()
49
+
50
+ return fig
51
+
52
+ def plot_feature_perm(clf, feature_names, X_test, y_test):
53
+ result = permutation_importance(
54
+ clf, X_test, y_test, n_repeats=10, random_state=42, n_jobs=2
55
+ )
56
+ forest_importances = pd.Series(result.importances_mean, index=feature_names)
57
+
58
+ fig, ax = plt.subplots()
59
+ forest_importances.plot.bar(yerr=result.importances_std, ax=ax)
60
+ ax.set_title("Feature importances using permutation on full model")
61
+ ax.set_ylabel("Mean accuracy decrease")
62
+ fig.tight_layout()
63
+
64
+ return fig
65
+
66
+
67
+
68
+ title = "Feature importances with a forest of trees 🌳"
69
+ description = """This example shows the use of a forest of trees to evaluate the importance of features on an artificial classification task.
70
+ The blue bars are the feature importances of the forest, along with their inter-trees variability represented by the error bars.
71
+ """
72
+
73
+ with gr.Blocks() as demo:
74
+ gr.Markdown(f"## {title}")
75
+ gr.Markdown(description)
76
+
77
+ # with gr.Column():
78
+ clf, feature_names, X_test, y_test = train_model()
79
+
80
+ with gr.Row():
81
+ plot = gr.Plot(plot_mean_decrease(clf, feature_names))
82
+ plot2 = gr.Plot(plot_feature_perm(clf, feature_names, X_test, y_test))
83
+
84
+ # input_data = gr.Dropdown(choices=feature_names, label="Feature", value="body-mass index")
85
+ # coef = gr.Textbox(label="Coefficients")
86
+ # mse = gr.Textbox(label="Mean squared error (MSE)")
87
+ # r2 = gr.Textbox(label="R2 score")
88
+
89
+ # input_data.change(fn=train_model, inputs=[input_data], outputs=[plot, coef, mse, r2], queue=False)
90
 
 
91
 
92
+ demo.launch(enable_queue=True)