|
|
|
|
|
import simplestart as ss |
|
|
|
from sklearn import datasets |
|
from sklearn.neighbors import KNeighborsClassifier |
|
from sklearn.model_selection import train_test_split |
|
from sklearn.metrics import accuracy_score |
|
|
|
|
|
# Page heading.
ss.md('''
## 模型训练
''')



# Load the Iris dataset that ships with scikit-learn and pull out the
# feature matrix and the label vector.
data = datasets.load_iris()
X, y = data.data, data.target

# Hold out 20% of the samples for evaluation; random_state is pinned so the
# split is reproducible between runs.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=0,
)


# Session variables used by the rest of this page.
ss.session["acc"] = ""  # accuracy shown on the page; overwritten by train()
ss.session["code"] = 0  # 1 while the code listing is shown, 0 while hidden
|
|
|
|
|
def train(event):
    """Fit a 3-nearest-neighbours classifier and publish its test accuracy.

    Registered as the ``onclick`` handler of the "Train" button. The model
    is fitted on the module-level training split, evaluated on the held-out
    split, and the accuracy (rounded to two decimals) is written to
    ``ss.session["acc"]``.

    Args:
        event: Argument passed in by the button callback; not used here.
    """
    model = KNeighborsClassifier(n_neighbors=3)
    model.fit(X_train, y_train)

    predictions = model.predict(X_test)
    score = accuracy_score(y_test, predictions)

    ss.session["acc"] = round(score, 2)
|
|
|
# Explanatory text for the training section. The bare "###" lines are kept
# exactly as written; they appear to act as spacers between paragraphs.
# NOTE(review): "@acc" inside this Markdown text may be treated as a
# reference to the session variable rather than literal text - confirm
# against the simplestart documentation and escape it if so.
ss.md('''
#### 模型训练的主要步骤:
首先,从数据集中加载 Iris 数据(包括特征和标签),并将这些数据划分为训练集和测试集,其中 80% 用于训练,20% 用于测试。接着,定义了一个训练函数,该函数使用 K-Nearest Neighbors(KNN)分类器进行训练,评估模型的预测精度,并将结果保存在一个页面会话变量中,以便在网页上显示。
###
在网页上,显示了一个训练按钮。当用户点击这个按钮时,训练函数会被触发,模型会在后台进行训练并计算测试集的预测精度。训练完成后,精度结果会更新到页面中,并以“Accuracy = @acc”格式展示给用户,其中 @acc is 训练过程中计算得到的预测精度值。
###
训练和测试速度特别快可能是因为 Iris 数据集非常小,只有 150 个样本和 4 个特征。此外,K-Nearest Neighbors(KNN)是一种简单且高效的算法,特别是在小数据集上表现较好,因此训练和测试过程迅速完成。
###
---
''')
# Accuracy line: "@acc" is passed as its own argument and, per the page text
# above, is displayed as the value stored in ss.session["acc"].
# The label is a plain string: it has no placeholders, so the former
# f-prefix was redundant.
ss.write('测试集的预测精度 Accuracy =', "@acc")

# Clicking the button runs train(), which updates ss.session["acc"].
ss.button("Train", onclick=train)



ss.md("---")
|
|
|
def conditioner(event):
    """Report whether the code listing should currently be shown.

    Passed to ``ss.when``; the listing counts as visible exactly when the
    session flag ``"code"`` equals 1.

    Args:
        event: Argument supplied by the caller; not used here.

    Returns:
        The result of ``ss.session["code"] == 1``.
    """
    code_flag = ss.session["code"]
    return code_flag == 1
|
|
|
def checkcode(event):
    """Button handler that switches the code listing on.

    Sets the session flag ``"code"`` to 1.

    Args:
        event: Argument passed in by the button callback; not used here.
    """
    shown = 1
    ss.session["code"] = shown
|
|
|
def hidecode(event):
    """Button handler that switches the code listing off.

    Sets the session flag ``"code"`` to 0.

    Args:
        event: Argument passed in by the button callback; not used here.
    """
    hidden = 0
    ss.session["code"] = hidden
|
|
|
# Two buttons toggle the session flag "code": checkcode sets it to 1,
# hidecode sets it back to 0.
ss.button("查看代码", onclick=checkcode)
ss.button("隐藏代码", onclick=hidecode)

# The listing below is wrapped in ss.when(conditioner); presumably it is
# rendered only while conditioner() returns True (session flag "code" == 1).
# The Markdown is kept flush-left so the fenced block is not indented.
# "\\@acc" is written with a doubled backslash: the runtime text is still
# "\@acc", but the source no longer contains the invalid escape sequence
# "\@", which Python 3.12+ reports as a SyntaxWarning.
with ss.when(conditioner):
    ss.md('''
```python
import simplestart as ss

from sklearn import datasets
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

#加载数据,并划分样本数据
data = datasets.load_iris()
X = data.data
y = data.target
ss.write(X.shape, y.shape)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

#页面会话变量
ss.session["acc"] = ""

#训练函数
def train(event):
    clf = KNeighborsClassifier(n_neighbors=3)
    clf.fit(X_train,y_train)

    y_pred = clf.predict(X_test)

    acc = accuracy_score(y_test, y_pred)
    acc = round(acc, 2)

    ss.session["acc"] = acc #将结果赋值给页面会话变量,相应页面显示值会自动响应

#显示在测试集上模型的准确率
ss.write(f'测试集的预测精度 Accuracy =', "\\@acc")

ss.button("Train", onclick = train)
```
''')
|
|
|
|
|
# Horizontal rule separating the code listing from the closing notes.
ss.md("---")


# Tip box (":::" container syntax) listing the advantages of KNN.
ss.md('''
::: tip
### KNN的优点:
简洁、易于理解、易于实现、无须估计参数,无须训练;
适合对稀有事件进行分类;
特别适用于多分类问题(Multi-label,对象具有多个类别标签)
:::
''')

# Pointer to an external article with a longer introduction to KNN.
ss.md('''
更多KNN介绍,请参考
[KNN分类算法介绍,用KNN分类鸢尾花数据集(iris)](https://blog.csdn.net/weixin_51756038/article/details/130096706)
''')