Welcome to R Square

K 最近邻算法(KNN)

楚新元 / 2023-11-01


理解近邻分类

你知道蛋白质、蔬菜和水果是怎么分类的吗?生活中我们发现既不脆也不甜的是蛋白质,脆而不甜的是蔬菜,而水果往往是甜的,有可能脆也有可能不脆。基于以上生活经验(人以群分,物以类聚),那么你知道西红柿是水果还是蔬菜呢?首先我们来看下面一组数据。

食物甜度脆度食物类型
葡萄85水果
四季豆37蔬菜
坚果36蛋白质
橙子73水果

现在如果我们知道西红柿的甜度为 6,脆度为 4,如果我们把这些数据放在横轴为甜度,纵轴为脆度的二维平面图上,我们很容易计算出西红柿与其他四种食物之间的直线距离。例如西红柿和四季豆之间的距离为:

$$d(西红柿, 四季豆) = \sqrt{(6-3)^2 + (4-7)^2} = 4.2$$

根据以上算法,我们分别计算了西红柿和葡萄、四季豆、坚果、橙子之间的距离,分别是2.2、4.2、3.6、1.4。

我们发现西红柿和橙子之间的距离最短,那么我们据此认为西红柿是一种水果。这里其实是只选了一个最近的“邻居”,即 \(k=1\),是一个 1NN 分类。

如果我们使用 \(k=3\) 的 KNN 算法。那么它会在三个最近邻居即橙子、葡萄和坚果之间进行投票表决。因为这个里面有两票归为水果(2/3 的票数),所以西红柿再次归为水果。

定义最近邻

在 KNN 算法中,常用的距离有三种,分别为曼哈顿距离、欧几里得距离和闵可夫斯基距离。在寻找最近邻的过程中,一般使用欧几里得距离来确定最近的邻居。

那么 k 究竟选择多少合适呢?

可以使用几种策略来选择 k 参数,第一个快速直接的解决方法是将 k 设置为训练实例数量的平方根;另一种方法是使用验证集上的优化工具选择 k,在这种情况下,将训练集进一步划分为训练集和验证集,并选择 k,以便使用训练数据将验证数据的预测准确性最大化,k 应该最小化诸如 MAPE 之类的预测精度统计量(Hyndman and Koehler 2006),应该注意的是,这种优化策略非常耗时;Martínez 等(2019)探索的第三个策略是使用多个具有不同 k 值的 KNN 模型,每个 KNN 模型都会生成其预测,并对不同模型的预测值求平均,以生成最终预测值(Martínez et al. 2019),此策略基于模型组合在时间序列预测中的成功实践(Hibon and Evgeniou 2005),这种方式,避免了使用费时的优化工具,并且预测也不基于唯一的 k 值。

待预测新实例目标值计算方法选择

一旦我们根据特征确定了对 k 个最近邻居,我们就可以汇总这 k 个最近邻居的目标得到待预测新实例的目标值。默认情况下对 k 个最近邻居的目标取平均值。但是,我们除了取 k 个最近邻居的目标取平均值外,还可以选择取去掉最大最小值后的平均数、中位数和加权平均数。如果选择加权平均数算法,其核心思想是将更多的权重分配给更近的邻居。

注意,如果根据加权平均法预测新实例的目标,则 k 参数的选择就不太重要,因为邻居在远离新实例时会获得一个较小的权重。在这种情况下,如果k的取值足够大,正好等于训练实例的数目时,可以视为广义回归神经网络来处理(Weizhong Yan 2012)。

KNN 算法对数据的要求

如果我们在数据集中加入另外一个特征,比如食物的辛辣度,辛辣度的取值在 0~100 多万,而甜度和脆度的取值在 1~10 之间,所以尺度的差异,导致辛辣对距离函数的影响远远超过了甜度和脆度,如果不对数据进行调整,那么我们可以预见,距离度量和辛辣度有很大的关系,而脆度、甜度的影响几乎可以忽略不计。

解决的方法便是对原数据进行标准化处理,使各个特征的值都落在 0~1 范围内,或者使各个特征在量上具有可比性。常用的方法有两种:

normalize_mm = \(x) {
  return ((x - min(x)) / (max(x) - min(x)))
}
normalize_z = function(x) {
  return ((x - mean(x)) / sd(x))
}

注意:计算非数值型数据的距离,需要将原数据先转化为数值型数据。一种典型的解决方案使利用虚拟变量编码。例如 1 表示男性,0 表示女性。

下面是一个 KNN 实战案例,进一步学习 KNN 算法的应用。

第一步收集数据

案例来自 R 语言实战第二版(Kabacoff 2011)。文中数据来源为威斯康星州乳腺癌数据集,本数据包含 699 个样本,11 个变量。

# 对原始数据添加变量名称
data(biopsy, package = "MASS")
names(biopsy) = c(
  "ID", "clumpThickness", "sizeUniformity", "shapeUniformity",
  "maginalAdhesion", "singleEpithelialCellSize", "bareNuclei",
  "blandChromatin", "normalNucleoli", "mitosis", "class"
)

数据集中的变量说明:

第二步探索和准备数据

# 清洗数据
biopsy |> 
  subset(select = -ID) |>  # 去掉 ID 列,此列属于无关变量
  na.omit() -> data  # 缺失值占比很少,此处直接删除

# 对除 class 列的数据进行标准化处理
df = as.data.frame(lapply(data[1:9], normalize_mm))

# 确定样本比例和样本量
split_size = 0.7
sample_size = floor(nrow(data) * split_size)

# 创建训练数据集和验证数据集
set.seed(1234)
train_ind = sample(nrow(data), size = sample_size)
train = df[train_ind, ]
validate = df[-train_ind, ]
train_labels = data[train_ind, 10]  # 训练数据集诊断结果
validate_labels = data[-train_ind, 10]  # 验证数据集诊断结果

# 对训练数据和验证数据做初步统计
library(dplyr)
train |> 
  cbind(train_labels) |> 
  rename(class = train_labels) |> 
  group_by(class) |> 
  summarise(
    total = n()
  ) -> train_stat

validate |> 
  cbind(validate_labels) |> 
  rename(class = validate_labels) |> 
  group_by(class) |> 
  summarise(
    total = n()
  ) -> validate_stat

第一个变量 ID 不纳入数据分析,最后一个变量 class 即输出变量。

对于每一个样本来说,另外九个变量是与判别恶性肿瘤相关的细胞特征,任一变量都不能单独作 为判别良性或恶性的标准,建模的目的是找到九个细胞特征的某种组合,从而实现对恶性肿瘤的 准确预测。

剔除缺失值,并随机分出训练集和验证集,其中 训练集中包含 478 个样本单元 (占 70%), 其中良性样本单元 302 个, 恶性样本单元 176 个; 验证集中包含 205 个样本单元 (占 30%), 其中良性 142 个, 恶性 63 个。

第三步基于数据训练模型

因为训练的样本有 478 个,开根后是 22,因此此处 k 取 22。

library(class)
knn_pred = knn(
  train = train,
  test = validate,
  cl = train_labels,
  k = 22
)

第四步评估模型的性能

library(gmodels)
CrossTable(
  x = validate_labels,
  y = knn_pred,
  dnn = c("Actual", "Predicted"),
  prop.chisq = FALSE
)
#> 
#>  
#>    Cell Contents
#> |-------------------------|
#> |                       N |
#> |           N / Row Total |
#> |           N / Col Total |
#> |         N / Table Total |
#> |-------------------------|
#> 
#>  
#> Total Observations in Table:  205 
#> 
#>  
#>              | Predicted 
#>       Actual |    benign | malignant | Row Total | 
#> -------------|-----------|-----------|-----------|
#>       benign |       140 |         2 |       142 | 
#>              |     0.986 |     0.014 |     0.693 | 
#>              |     0.979 |     0.032 |           | 
#>              |     0.683 |     0.010 |           | 
#> -------------|-----------|-----------|-----------|
#>    malignant |         3 |        60 |        63 | 
#>              |     0.048 |     0.952 |     0.307 | 
#>              |     0.021 |     0.968 |           | 
#>              |     0.015 |     0.293 |           | 
#> -------------|-----------|-----------|-----------|
#> Column Total |       143 |        62 |       205 | 
#>              |     0.698 |     0.302 |           | 
#> -------------|-----------|-----------|-----------|
#> 
#> 

左上角代表真阴性,右下角代表真阳性。预测的准确率为 (140+60)/205*100%=97.56%。同时我们也发现位于左下角的 3 个样本,实际为恶性,但是却被 KNN 错误地归为良性,即假阴性;右上角 2 个样本,实际为良性,却被 KNN 错误地归为恶性,即假阳性。但是预测的准确率还是比较高的,模型令人满意。

第五步提高模型的性能

这里我们可以尝试两种简单的改变,一是数据标准化处理时可以考虑采用 z-分数标准化,二是尝试几个不同的 k 值。需要注意的是,过分的追求预测的精度,可能导致过拟合,加大了拟合噪音的可能,从而使泛化能力变弱。

在确定 k 值方面,caret 包又可以大显身手了。

library(caret)
set.seed(1234) # 设置随机数种子,方便重复性研究
grid = expand.grid(.k = seq(2, 200, by = 1))
control = trainControl(method = "cv")

train |> 
  cbind(train_labels) |> 
  rename(class = train_labels) -> train_mm

knn_train = train(
  class ~ .,
  data = train_mm,
  method = "knn",
  trControl = control,
  tuneGrid = grid
)

knn_train
#> k-Nearest Neighbors 
#> 
#> 478 samples
#>   9 predictor
#>   2 classes: 'benign', 'malignant' 
#> 
#> No pre-processing
#> Resampling: Cross-Validated (10 fold) 
#> Summary of sample sizes: 430, 430, 429, 430, 430, 430, ... 
#> Resampling results across tuning parameters:
#> 
#>   k    Accuracy   Kappa    
#>     2  0.9561134  0.9041020
#>     3  0.9686595  0.9325849
#>     4  0.9644485  0.9234264
#>     5  0.9623652  0.9184103
#>     6  0.9623652  0.9187275
#>     7  0.9602819  0.9143319
#>     8  0.9603262  0.9145271
#>     9  0.9602819  0.9142288
#>    10  0.9581985  0.9097377
#>    11  0.9581985  0.9097377
#>    12  0.9581985  0.9095423
#>    13  0.9602819  0.9141333
#>    14  0.9581985  0.9097377
#>    15  0.9581985  0.9095423
#>    16  0.9581985  0.9095423
#>    17  0.9602819  0.9141333
#>    18  0.9561152  0.9048436
#>    19  0.9561152  0.9048436
#>    20  0.9581985  0.9094346
#>    21  0.9581985  0.9094346
#>    22  0.9561152  0.9048436
#>    23  0.9581985  0.9094346
#>    24  0.9602819  0.9139290
#>    25  0.9581985  0.9093380
#>    26  0.9581985  0.9093380
#>    27  0.9581985  0.9093380
#>    28  0.9581985  0.9093380
#>    29  0.9581985  0.9093380
#>    30  0.9581985  0.9093380
#>    31  0.9561152  0.9048436
#>    32  0.9561152  0.9048436
#>    33  0.9581985  0.9093380
#>    34  0.9561152  0.9048436
#>    35  0.9561152  0.9048436
#>    36  0.9561152  0.9048436
#>    37  0.9561152  0.9048436
#>    38  0.9561152  0.9048436
#>    39  0.9561152  0.9048436
#>    40  0.9561152  0.9048436
#>    41  0.9561152  0.9048436
#>    42  0.9561152  0.9048436
#>    43  0.9561152  0.9048436
#>    44  0.9561152  0.9048436
#>    45  0.9561152  0.9048436
#>    46  0.9561152  0.9048436
#>    47  0.9561152  0.9048436
#>    48  0.9561152  0.9048436
#>    49  0.9561152  0.9048436
#>    50  0.9561152  0.9048436
#>    51  0.9561152  0.9048436
#>    52  0.9561152  0.9048436
#>    53  0.9561152  0.9048436
#>    54  0.9561152  0.9048436
#>    55  0.9561152  0.9048436
#>    56  0.9561152  0.9048436
#>    57  0.9561152  0.9048436
#>    58  0.9561152  0.9048436
#>    59  0.9561152  0.9048436
#>    60  0.9561152  0.9048436
#>    61  0.9561152  0.9048436
#>    62  0.9561152  0.9048436
#>    63  0.9561152  0.9048436
#>    64  0.9561152  0.9048436
#>    65  0.9561152  0.9048436
#>    66  0.9561152  0.9048436
#>    67  0.9561152  0.9048436
#>    68  0.9561152  0.9048436
#>    69  0.9539876  0.9000542
#>    70  0.9539876  0.9000542
#>    71  0.9539876  0.9000542
#>    72  0.9539876  0.9000542
#>    73  0.9539876  0.9000542
#>    74  0.9519042  0.8954388
#>    75  0.9519042  0.8954388
#>    76  0.9519042  0.8954388
#>    77  0.9519042  0.8954388
#>    78  0.9519042  0.8954388
#>    79  0.9519042  0.8954388
#>    80  0.9519042  0.8954388
#>    81  0.9498634  0.8909976
#>    82  0.9498634  0.8909976
#>    83  0.9498634  0.8909976
#>    84  0.9498634  0.8909976
#>    85  0.9498634  0.8909976
#>    86  0.9519042  0.8954388
#>    87  0.9498634  0.8909976
#>    88  0.9498634  0.8909976
#>    89  0.9498634  0.8909976
#>    90  0.9498634  0.8909976
#>    91  0.9498634  0.8909976
#>    92  0.9498634  0.8909976
#>    93  0.9498634  0.8909976
#>    94  0.9498634  0.8909976
#>    95  0.9498634  0.8909976
#>    96  0.9498634  0.8909976
#>    97  0.9498634  0.8909976
#>    98  0.9498634  0.8909976
#>    99  0.9498634  0.8909976
#>   100  0.9498634  0.8909976
#>   101  0.9498634  0.8909976
#>   102  0.9498634  0.8909976
#>   103  0.9498634  0.8905509
#>   104  0.9498634  0.8905509
#>   105  0.9498634  0.8905509
#>   106  0.9498634  0.8905509
#>   107  0.9498634  0.8905509
#>   108  0.9498634  0.8905509
#>   109  0.9498634  0.8905509
#>   110  0.9498634  0.8905509
#>   111  0.9477801  0.8857393
#>   112  0.9477801  0.8857393
#>   113  0.9477801  0.8857393
#>   114  0.9477801  0.8857393
#>   115  0.9435691  0.8761288
#>   116  0.9456967  0.8810452
#>   117  0.9435691  0.8761288
#>   118  0.9435691  0.8761288
#>   119  0.9435691  0.8761288
#>   120  0.9435691  0.8761288
#>   121  0.9414857  0.8712563
#>   122  0.9414857  0.8712563
#>   123  0.9414857  0.8712563
#>   124  0.9414857  0.8712563
#>   125  0.9394024  0.8666598
#>   126  0.9394024  0.8666598
#>   127  0.9394024  0.8666598
#>   128  0.9394024  0.8666598
#>   129  0.9394024  0.8666598
#>   130  0.9394024  0.8666598
#>   131  0.9394024  0.8666598
#>   132  0.9394024  0.8666598
#>   133  0.9373191  0.8617350
#>   134  0.9373191  0.8617350
#>   135  0.9373191  0.8617350
#>   136  0.9373191  0.8617350
#>   137  0.9373191  0.8617350
#>   138  0.9373191  0.8617350
#>   139  0.9373191  0.8617350
#>   140  0.9373191  0.8617350
#>   141  0.9373191  0.8617350
#>   142  0.9373191  0.8617350
#>   143  0.9373191  0.8617350
#>   144  0.9373191  0.8617350
#>   145  0.9373191  0.8617350
#>   146  0.9373191  0.8617350
#>   147  0.9373191  0.8617350
#>   148  0.9373191  0.8617350
#>   149  0.9373191  0.8617350
#>   150  0.9352357  0.8570328
#>   151  0.9352357  0.8570328
#>   152  0.9352357  0.8570328
#>   153  0.9352357  0.8570328
#>   154  0.9352357  0.8570328
#>   155  0.9352357  0.8570328
#>   156  0.9352357  0.8570328
#>   157  0.9352357  0.8570328
#>   158  0.9352357  0.8570328
#>   159  0.9352357  0.8570328
#>   160  0.9352357  0.8570328
#>   161  0.9352357  0.8570328
#>   162  0.9331524  0.8519908
#>   163  0.9331524  0.8519908
#>   164  0.9331524  0.8519908
#>   165  0.9331524  0.8519908
#>   166  0.9331524  0.8519908
#>   167  0.9331524  0.8519908
#>   168  0.9331524  0.8519908
#>   169  0.9331524  0.8519908
#>   170  0.9331524  0.8519908
#>   171  0.9331524  0.8519908
#>   172  0.9331524  0.8519908
#>   173  0.9331524  0.8519908
#>   174  0.9310248  0.8470745
#>   175  0.9310248  0.8472014
#>   176  0.9288971  0.8422851
#>   177  0.9310248  0.8472014
#>   178  0.9310248  0.8472014
#>   179  0.9310248  0.8472014
#>   180  0.9288971  0.8422851
#>   181  0.9288971  0.8422851
#>   182  0.9310248  0.8472014
#>   183  0.9289839  0.8426555
#>   184  0.9268563  0.8377391
#>   185  0.9268563  0.8377391
#>   186  0.9268563  0.8377391
#>   187  0.9268563  0.8377391
#>   188  0.9247286  0.8326907
#>   189  0.9226010  0.8276423
#>   190  0.9226010  0.8276423
#>   191  0.9226010  0.8276423
#>   192  0.9226010  0.8276423
#>   193  0.9226010  0.8276423
#>   194  0.9205176  0.8224788
#>   195  0.9205176  0.8224788
#>   196  0.9205176  0.8224788
#>   197  0.9205176  0.8224788
#>   198  0.9205176  0.8224788
#>   199  0.9205176  0.8224788
#>   200  0.9205176  0.8224788
#> 
#> Accuracy was used to select the optimal model using the largest value.
#> The final value used for the model was k = 3.

报告显示当 \(k=3\) 时模型最优3,此时模型的准确率最高,为 96.87%。其中:Kappa 统计量(用于测量两个分类器对观测值分类的一致性)对正确率进行了修正,去除了仅靠偶然性(或随机性)获得正确分类的因素。

下面我们利用 \(k=3\) 重新训练模型:

knn_pred_new = knn(
  train = train,
  test = validate,
  cl = train_labels,
  k = 3
) 
CrossTable(
  x = validate_labels,
  y = knn_pred_new,
  dnn = c("Actual", "Predicted"),
  prop.chisq = FALSE
)
#> 
#>  
#>    Cell Contents
#> |-------------------------|
#> |                       N |
#> |           N / Row Total |
#> |           N / Col Total |
#> |         N / Table Total |
#> |-------------------------|
#> 
#>  
#> Total Observations in Table:  205 
#> 
#>  
#>              | Predicted 
#>       Actual |    benign | malignant | Row Total | 
#> -------------|-----------|-----------|-----------|
#>       benign |       139 |         3 |       142 | 
#>              |     0.979 |     0.021 |     0.693 | 
#>              |     0.993 |     0.046 |           | 
#>              |     0.678 |     0.015 |           | 
#> -------------|-----------|-----------|-----------|
#>    malignant |         1 |        62 |        63 | 
#>              |     0.016 |     0.984 |     0.307 | 
#>              |     0.007 |     0.954 |           | 
#>              |     0.005 |     0.302 |           | 
#> -------------|-----------|-----------|-----------|
#> Column Total |       140 |        65 |       205 | 
#>              |     0.683 |     0.317 |           | 
#> -------------|-----------|-----------|-----------|
#> 
#> 

我们比较两次结果,我们发现假阴性减少了 2 个,假阳性增加 1 个,总体上预测的精度有所提高,为 (139+62)/205*100%=98.05%。

最后需要指出的是,还有其他方法可以对距离进行加权,kknn 包提供了 10 中不同的加权方式,有兴趣可以尝试。

为了保证结果的可重现,我把系统环境信息提供如下:

xfun::session_info(c("class", "gmodels", "caret"))
#> R version 4.5.1 (2025-06-13)
#> Platform: x86_64-pc-linux-gnu
#> Running under: Arch Linux
#> 
#> Locale:
#>   LC_CTYPE=zh_CN.UTF-8       LC_NUMERIC=C              
#>   LC_TIME=zh_CN.UTF-8        LC_COLLATE=zh_CN.UTF-8    
#>   LC_MONETARY=zh_CN.UTF-8    LC_MESSAGES=zh_CN.UTF-8   
#>   LC_PAPER=zh_CN.UTF-8       LC_NAME=C                 
#>   LC_ADDRESS=C               LC_TELEPHONE=C            
#>   LC_MEASUREMENT=zh_CN.UTF-8 LC_IDENTIFICATION=C       
#> 
#> Package version:
#>   caret_7.0-1          class_7.3-23         cli_3.6.5           
#>   clock_0.7.3          codetools_0.2.20     compiler_4.5.1      
#>   cpp11_0.5.2          data.table_1.17.8    diagram_1.6.5       
#>   digest_0.6.37        dplyr_1.1.4          e1071_1.7.16        
#>   farver_2.1.2         foreach_1.5.2        future_1.67.0       
#>   future.apply_1.20.0  gdata_3.0.1          generics_0.1.4      
#>   ggplot2_4.0.0        globals_0.18.0       glue_1.8.0          
#>   gmodels_2.19.1       gower_1.0.2          graphics_4.5.1      
#>   grDevices_4.5.1      grid_4.5.1           gtable_0.3.6        
#>   gtools_3.9.5         hardhat_1.4.2        ipred_0.9.15        
#>   isoband_0.2.7        iterators_1.0.14     KernSmooth_2.23.26  
#>   labeling_0.4.3       lattice_0.22.7       lava_1.8.1          
#>   lifecycle_1.0.4      listenv_0.9.1        lubridate_1.9.4     
#>   magrittr_2.0.4       MASS_7.3.65          Matrix_1.7.4        
#>   methods_4.5.1        ModelMetrics_1.2.2.2 nlme_3.1.168        
#>   nnet_7.3.20          numDeriv_2016.8.1.1  parallel_4.5.1      
#>   parallelly_1.45.1    pillar_1.11.1        pkgconfig_2.0.3     
#>   plyr_1.8.9           pROC_1.19.0.1        prodlim_2025.4.28   
#>   progressr_0.17.0     proxy_0.4.27         purrr_1.1.0         
#>   R6_2.6.1             RColorBrewer_1.1.3   Rcpp_1.1.0          
#>   recipes_1.3.1        reshape2_1.4.4       rlang_1.1.6         
#>   rpart_4.1.24         S7_0.2.0             scales_1.4.0        
#>   shape_1.4.6.1        sparsevctrs_0.3.4    splines_4.5.1       
#>   SQUAREM_2021.1       stats_4.5.1          stats4_4.5.1        
#>   stringi_1.8.7        stringr_1.5.2        survival_3.8.3      
#>   tibble_3.3.0         tidyr_1.3.1          tidyselect_1.2.1    
#>   timechange_0.3.0     timeDate_4051.111    tools_4.5.1         
#>   tzdb_0.5.0           utf8_1.2.6           utils_4.5.1         
#>   vctrs_0.6.5          viridisLite_0.4.2    withr_3.0.2

参考文献

Hibon, Michèle, and Theodoros Evgeniou. 2005. “To Combine or Not to Combine: Selecting Among Forecasts and Their Combinations.” International Journal of Forecasting 21 (1): 15–24. https://doi.org/10.1016/j.ijforecast.2004.05.002.

Hyndman, Rob J., and Anne B. Koehler. 2006. “Another Look at Measures of Forecast Accuracy.” International Journal of Forecasting 22 (4): 679–88. https://doi.org/10.1016/j.ijforecast.2006.03.001.

Kabacoff, Robert. 2011. R in Action. Manning Publications. https://book.douban.com/subject/6126331/.

Martínez, Francisco, María Pilar Frías, María Dolores Pérez, and Antonio Jesús Rivera. 2019. “A Methodology for Applying k-Nearest Neighbor to Time Series Forecasting.” Artificial Intelligence Review 52 (3): 2019–37. https://doi.org/10.1007/s10462-017-9593-z.

Weizhong Yan. 2012. “Toward Automatic Time-Series Forecasting Using Neural Networks.” IEEE Transactions on Neural Networks and Learning Systems 23 (7): 1028–39. https://doi.org/10.1109/TNNLS.2012.2198074.


  1. 换句话说,你不能只听邻居的片面之词啊! ↩︎

  2. 换句话说,了解他你也不用找外国人打听啊,你不妨多打听下他周围的人意见。 ↩︎

  3. 笔者也尝试了利用 z-分数标准化对原数据进行处理,根据 Kappa 统计量确定最优 k 值为 3。 ↩︎