醋醋百科网

Good Luck To You!

动手学机器学习(三)k 近邻算法

k 近邻(k-nearest neighbor, KNN)算法:相似的数据往往拥有相同的类别,同一种类的数据之间特征更为相似,而不同种类的数据之间特征差别更大。

目标是判断样本的类别。KNN 首先会观察与该样本点距离最近的个样本,统计这些样本所属的类别。然后,将当前样本归到出现次数最多的类中。

KNN 的基本思路是让当前样本的分类服从邻居中的多数分类。

3.1 用 KNN 算法完成分类任务

在 MNIST 数据集上应用 KNN 算法,完成分类任务。MNIST 是手写数字数据集,其中包含了很多手写数字 0~9 的黑白图像,每张图像都由 28*28 个像素点组成。读者可以在 MNIST 的官方网站上得到更多数据集的信息。读入后,每个像素点用 1 或 0 表示,1 代表黑色像素,属于图像背景;0 代表白色像素,属于手写数字。我们的任务是用 KNN 对不同的手写数字进行分类。为了更清晰地展示数据集的内容,下面先将前两个数据点转成黑白图像显示出来。此外,把每个数据集都按 8:2 的比例随机划分成训练集(training set)和测试集(test set)。

import matplotlib.pyplot as plt
import numpy as np
import os

# 读入 mnist 数据集
# 使用 numpy 的 loadtxt 函数从文件'mnist_x'中读取数据,数据以空格分隔
m_x = np.loadtxt('mnist_x', delimiter=' ')
# 从文件'mnist_y'中读取数据
m_y = np.loadtxt('mnist_y')

# 数据集可视化
# 选取 m_x 中的第一个样本,并将其转换为整数型数组,然后重塑为 28x28 的矩阵,以表示一张 28x28 的图像
data = np.reshape(np.array(m_x[0], dtype=int), [28, 28])
# 创建一个新的图形窗口
plt.figure()
# 使用 imshow 函数将图像数据显示出来,颜色映射为灰度图
plt.imshow(data, cmap='gray')

# 将数据集分为训练集和测试集
# 设定训练集占总数据集的比例
ratio = 0.8
# 根据比例计算分割点的索引位置
split = int(len(m_x) * ratio)

# 打乱数据
# 设置随机数种子为 0,以保证结果的可重复性
np.random.seed(0)
# 生成一个从 0 到 m_x 长度的随机排列的索引数组
idx = np.random.permutation(np.arange(len(m_x)))
# 根据随机索引重新排列 m_x
m_x = m_x[idx]
# 根据随机索引重新排列 m_y
m_y = m_y[idx]

# 分割数据集,将前 split 个样本作为训练集
x_train, x_test = m_x[:split], m_x[split:]
# 分割标签集,将前 split 个样本的标签作为训练集标签
y_train, y_test = m_y[:split], m_y[split:]

def distance(a, b):
    # 计算 a 和 b 的差
    # np.square 函数对差进行平方操作
    # np.sum 函数对平方后的差进行求和
    # np.sqrt 函数对求和结果进行开平方,得到欧几里得距离
    return np.sqrt(np.sum(np.square(a - b)))

将 KNN 算法定义成类,其初始化参数是K和类别的数量。

class KNN:
    def __init__(self, k, label_num):
        # 初始化 KNN 类的实例,接收 k 和类别数量 label_num 作为参数
        # 存储 k 值,用于后续 K 近邻的计算
        self.k = k
        # 存储类别数量,可能用于类别统计等操作
        self.label_num = label_num  

    def fit(self, x_train, y_train):
        # 将传入的训练数据 x_train 和训练标签 y_train 存储在类的属性中,以便后续使用
        self.x_train = x_train
        self.y_train = y_train

    def get_knn_indices(self, x):
        # 获取距离目标样本点 x 最近的 K 个样本点的标签
        # 计算已知样本 self.x_train 中的每个样本与目标样本 x 的距离
        # 使用 map 函数和 lambda 表达式,将 distance 函数应用于每个样本
        dis = list(map(lambda a: distance(a, x), self.x_train))
        # 对距离列表 dis 进行排序,并返回排序后的索引
        # np.argsort 函数返回的是排序后元素在原数组中的索引
        knn_indices = np.argsort(dis)
        # 取距离最近的 K 个样本的索引
        knn_indices = knn_indices[:self.k]
        return knn_indices

    def get_label(self, x):
        # 对 KNN 方法的具体实现,观察 K 个近邻并使用 np.argmax 获取其中数量最多的类别
        # 获取距离样本 x 最近的 K 个样本的索引
        knn_indices = self.get_knn_indices(x)
        # 初始化一个长度为 self.label_num 的零数组,用于存储类别统计结果
        label_statistic = np.zeros(shape=[self.label_num])
        # 遍历 K 个最近邻的索引
        for index in knn_indices:
            # 获取对应样本的标签,并将其转换为整数
            label = int(self.y_train[index])
            # 对相应类别的统计结果加 1
            label_statistic[label] += 1
        # 返回统计结果中数量最多的类别,使用 np.argmax 找到最大值的索引
        return np.argmax(label_statistic)

    def predict(self, x_test):
        # 预测样本 x_test 的类别
        # 初始化一个数组,用于存储预测的测试集标签,长度为测试集样本的数量,数据类型为整数
        predicted_test_labels = np.zeros(shape=[len(x_test)], dtype=int)
        # 遍历测试集中的每个样本
        for i, x in enumerate(x_test):
            # 对每个样本调用 get_label 函数进行预测,并存储预测结果
            predicted_test_labels[i] = self.get_label(x)
        return predicted_test_labels

在测试集上观察算法的效果,并对不同K的的取值进行测试

# 循环遍历 k 的值,从 1 到 9
for k in range(1, 10):
    # 创建 KNN 类的实例,k 的值从 1 到 9,类别数量为 10
    knn = KNN(k, label_num=10)
    # 调用 fit 方法,将训练数据 x_train 和 y_train 传递给 KNN 类进行训练
    knn.fit(x_train, y_train)
    # 调用 predict 方法,对测试数据 x_test 进行预测,得到预测的标签
    predicted_labels = knn.predict(x_test)
    # 计算预测准确率,通过比较预测标签和真实标签 y_test 相等的元素的比例
    accuracy = np.mean(predicted_labels == y_test)
    # 打印出 k 的值和对应的预测准确率,保留一位小数
    print(f'K的取值为 {k}, 预测准确率为 {accuracy * 100:.1f}%')

K的取值为 1, 预测准确率为 88.5%

K的取值为 2, 预测准确率为 88.0%

K的取值为 3, 预测准确率为 87.5%

K的取值为 4, 预测准确率为 87.5%

K的取值为 5, 预测准确率为 88.5%

K的取值为 6, 预测准确率为 88.5%

K的取值为 7, 预测准确率为 88.0%

K的取值为 8, 预测准确率为 87.0%

K的取值为 9, 预测准确率为 87.0%

参考文献

【1】张伟楠, 赵寒烨, 俞勇. (2023). 动手学机器学习[M]. 北京:人民邮电出版社.

【2】在线教程、PPT、视频和源代码
https://hml.boyuai.com/books

【3】软件源代码

https://github.com/boyu-ai/Hands-on-ML

控制面板
您好,欢迎到访网站!
  查看权限
网站分类
最新留言