【机器学习05】kNN小结:解决鸢尾花和手写数字识别分类

kNN 解决分类问题的套路。

摘要:运用 kNN 解决鸢尾花和手写数字识别分类问题,熟悉 Sklearn 的一般套路。

今天我们以两个常见的数据集鸢尾花手写数字识别为例,练习 Sklearn 使用 kNN 算法解决机器学习分类问题,作为对之前四篇文章的小结。

练习完这两个案例,相信会对 kNN 算法有一个比较全面的理解,同时能学会 Sklearn 处理机器学习的一些固定套路,为下一步继续 kNN 算法做准备。

预测鸢尾花数据集分类

鸢(yuan)尾花是 20 世纪 30 年代的一个经典数据集。该数据集包括三种花共 150 个样本,每个样本有 4 个数值型特征,分别是花萼长度(cm)、花萼宽度(cm)、花瓣长度(cm)、花瓣宽度(cm)。

取数据集的前 5 行预览一下:

sepal length (cm)sepal width (cm)petal length (cm)petal width (cm)Class
05.13.51.40.20
14.93.01.40.20
24.73.21.30.20
34.63.11.50.20
45.03.61.40.20

现在的任务是:随机给一些样本,要判定它们分别属于哪一种花。和葡萄酒数据集的问题很相似。

所以我们同样可以用 kNN 算法来找到答案。思路很简单,加载数据集并划分训练集和测试集,在训练集上训练模型,然后把测试集应用到模型中,预测样本分别属于哪类花,最后计算分类的准确率。

mark

最后分类准确度达到了 97.8%,45 个测试花的样本仅预测错了 1 样本,准确率相当高。这还是在我们没有对模型做任何调优的情况下得到的。可见 kNN 算法的确是种效果很好的算法

接着,再来尝试一个相对大型点的分类数据集手写数字识别数据集,看看 kNN 算法性能如何。

手写数字识别数据集预测

mark

这个数据集来源 1998年 的一个手写数字实验。包含 1797 个样本,每个样本有 64 个特征(由 8 * 8 构成的 64 个数字像素点)每个样本的标签分别是 0-9 自然数中的一个。

现在的任务是:随机给一些数字样本,判定它们是哪个数字。这个任务 kNN 模型也能很好地完成,过程和刚才的鸢尾花一样,我们就直接贴代码:

plt.imshow 是一个图像处理函数,详细使用可以参考:plt.imshow 教程

%%time 是 jupyter book 的一个魔法命令,可以计算单元格执行的运算时间

箭头处对比了我们手写的模型和 Sklearn 中的模型的运行时间:8.74s 和 105ms,我们的算法慢了 80 倍差距非常大,主要是我们使用的是最简单粗暴的算法,而 Sklearn 用了更优化的方法。

这里也反映了 kNN 算法的运行时间会随着数据集维度增大而增大,所以得优化 kNN 算法才可以,比如使用 kd 树、球树等,我们放到后面再讲。

到这儿,我们练习了三个实例,相信对 kNN 算法有了比较深的认识了,不过现在又有一个新的问题。

我们在建立模型时,一直默认 k 参数(选择近邻样本点个数)为 3,这个参数对模型分类结果影响很大,它还可以是很多其他值,那是不是选 3 得到的模型就一定是最好的呢?答案显然不会这么绝对。另外,在第一篇文章中,我们计算模型距离时默认使用的是欧拉距离,而欧拉距离是不是对任何数据集都是最好的计算方法呢,答案显然也不绝对。

所以这里就涉及到了 kNN 算法的超参数问题,上面的 k 和距离都是超参数。不同的超参数会得到不同的 kNN 模型,为了得到更好的 kNN 模型,我们就需要好好了解一下超参数,下一篇文章就来介绍它。

本文的 jupyter notebook 代码,可以在我公众号:「高级农民工」后台回复「kNN5」得到,加油!

你一打赏,我就写得更来劲了
0%