demo_iris_classification / pages /03model_train.py
test2023h5's picture
Upload 15 files
30f79a7 verified
### 模型训练
import simplestart as ss
from sklearn import datasets
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
ss.md('''
## 模型训练
''')
#加载数据,并划分样本数据
data = datasets.load_iris()
X = data.data
y = data.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
#页面会话变量
ss.session["acc"] = ""
ss.session["code"] = 0
#训练函数
def train(event):
clf = KNeighborsClassifier(n_neighbors=3)
clf.fit(X_train,y_train)
y_pred = clf.predict(X_test)
acc = accuracy_score(y_test, y_pred)
acc = round(acc, 2)
ss.session["acc"] = acc #将结果赋值给页面会话变量,相应页面显示值会自动响应
ss.md('''
#### 模型训练的主要步骤:
首先,从数据集中加载 Iris 数据(包括特征和标签),并将这些数据划分为训练集和测试集,其中 80% 用于训练,20% 用于测试。接着,定义了一个训练函数,该函数使用 K-Nearest Neighbors(KNN)分类器进行训练,评估模型的预测精度,并将结果保存在一个页面会话变量中,以便在网页上显示。
###
在网页上,显示了一个训练按钮。当用户点击这个按钮时,训练函数会被触发,模型会在后台进行训练并计算测试集的预测精度。训练完成后,精度结果会更新到页面中,并以“Accuracy = @acc”格式展示给用户,其中 @acc 是训练过程中计算得到的预测精度值。
###
训练和测试速度特别快可能是因为 Iris 数据集非常小,只有 150 个样本和 4 个特征。此外,K-Nearest Neighbors(KNN)是一种简单且高效的算法,特别是在小数据集上表现较好,因此训练和测试过程迅速完成。
###
---
''')
ss.write(f'测试集的预测精度 Accuracy =', "@acc")
ss.button("Train", onclick = train)
#ui
ss.md("---")
def conditioner(event):
return ss.session["code"] == 1
def checkcode(event):
ss.session["code"] = 1
def hidecode(event):
ss.session["code"] = 0
ss.button("查看代码", onclick = checkcode)
ss.button("隐藏代码", onclick = hidecode)
with ss.when(conditioner):
ss.md('''
```python
import simplestart as ss
from sklearn import datasets
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
#加载数据,并划分样本数据
data = datasets.load_iris()
X = data.data
y = data.target
ss.write(X.shape, y.shape)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
#页面会话变量
ss.session["acc"] = ""
#训练函数
def train(event):
clf = KNeighborsClassifier(n_neighbors=3)
clf.fit(X_train,y_train)
y_pred = clf.predict(X_test)
acc = accuracy_score(y_test, y_pred)
acc = round(acc, 2)
ss.session["acc"] = acc #将结果赋值给页面会话变量,相应页面显示值会自动响应
#显示在测试集上模型的准确率
ss.write(f'测试集的预测精度 Accuracy =', "\@acc")
ss.button("Train", onclick = train)
```
''')
ss.md("---")
ss.md('''
::: tip
### KNN的优点:
简洁、易于理解、易于实现、无须估计参数,无须训练;
适合对稀有事件进行分类;
特别适用于多分类问题(Multi-label,对象具有多个类别标签)
:::
''')
ss.md('''
更多KNN介绍,请参考
[KNN分类算法介绍,用KNN分类鸢尾花数据集(iris)](https://blog.csdn.net/weixin_51756038/article/details/130096706)
''')