Skip to content

1.6. 最近邻方法(Nearest Neighbors)

sklearn.neighbors提供了基于邻居的(neighbors-based)的无监督学习和监督学习的方法。无监督的最近邻是许多其他方法的基础,尤其是流行学习(manifold learning)和谱聚类(spectral clustering)。基于最近邻的监督学习分为两种:分类:针对的是具有离散标签的数据;回归:针对的是具有连续标签的数据。

最近邻方法的原理是找到距离新样本最近的预定义数量的训练样本,然后从这几个已知标签的样本点中预测新样本的标签。样本数可以是用户定义的常数(K-最近邻学习),也可以根据点的局部密度而有所不同(基于半径的最近邻学习)。距离通常可以作为测度标准:标准欧几里德距离是最常见的选择。基于最近邻方法被称为非通用(non-generalizing)机器学习方法,因为它们只是“记住”其所有训练数据(可能转换为快速索引结构,例如 Ball TreeKD Tree)。

尽管这个算法很简单,但最近邻方法已成功解决了许多分类和回归问题,包括手写数字和卫星图像场景。作为非参数方法,它通常成功应用于在决策边界非常不规则的分类情景下。

sklearn.neighbors可以将NumPy数组或 scipy.sparse矩阵作为输入。对于稠密矩阵,大多数可能的距离测度都是支持的。对于稀疏矩阵,支持使用任意Minkowski测度进行搜索。

有许多学习例程都依赖于最近邻方法。期中一个例子是核密度估计,该例子将在密度估计章节中进行讨论。

1.6.1. 无监督最近邻(Unsupervised Nearest Neighbors)

NearestNeighbors实现无监督的最近邻算法。它是三种不同的最近邻算法的统一接口:BallTreeKDTree,以及基于sklearn.metrics.pairwise的暴力搜索算法(brute-force search)。选择那种最近邻搜索算法是通过关键字'algorithm'来控制的,该关键字 'algorithm' 的取值必须是 ['auto', 'ball_tree', 'kd_tree', 'brute'] 其中的一个。当设置为 'auto'时,算法会尝试从训练数据中确定最佳方法。关于上述每个取值的优缺点的讨论,请参阅《最近邻算法》。

警告:

关于最近邻居算法,如果有两个相邻的值 $k+1$ 和 $k$ ,它们具有相同的距离但是标签不同的话,运行结果将取决于训练数据的顺序。

1.6.1.1. 寻找最近的邻居

对于找到两组数据集中最近邻点的简单任务,可以使用sklearn.neighbors中的无监督算法:

>>> from sklearn.neighbors import NearestNeighbors
>>> import numpy as np
>>> X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
>>> nbrs = NearestNeighbors(n_neighbors=2, algorithm='ball_tree').fit(X)
>>> distances, indices = nbrs.kneighbors(X)
>>> indices
array([[0, 1],
       [1, 0],
       [2, 1],
       [3, 4],
       [4, 3],
       [5, 4]]...)
>>> distances
array([[0.        , 1.        ],
       [0.        , 1.        ],
       [0.        , 1.41421356],
       [0.        , 1.        ],
       [0.        , 1.        ],
       [0.        , 1.41421356]])

因为查询集(query set)与训练集(training set)匹配,所以每个点的最近邻点是其自身,距离为0。

还可以有效地生成一个稀疏图(sparse graph)来标识相连点之间的连接情况:

>>> nbrs.kneighbors_graph(X).toarray()
array([[1., 1., 0., 0., 0., 0.],
       [1., 1., 0., 0., 0., 0.],
       [0., 1., 1., 0., 0., 0.],
       [0., 0., 0., 1., 1., 0.],
       [0., 0., 0., 1., 1., 0.],
       [0., 0., 0., 0., 1., 1.]])

由于数据集是结构化的,因此按索引顺序的相邻点在参数空间也是相邻点,从而生成了近似K-近邻的块对角矩阵(block-diagonal matrix)。 这种稀疏图在各种利用样本点之间的空间关系进行无监督学习的情况下都很有用:请查看sklearn.manifold.Isomapsklearn.manifold.LocallyLinearEmbeddingsklearn.cluster.SpectralClustering

1.6.1.2. KDTree类 和 BallTree类

另外,我们还可以直接使用KDTreeBallTree类来查找最近邻。这是上面提到过的NearestNeighbors所封装的功能。KDTree 和 BallTree 具有相同的接口;我们将在此处展示如何使用KDTree:

>>> from sklearn.neighbors import KDTree
>>> import numpy as np
>>> X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
>>> kdt = KDTree(X, leaf_size=30, metric='euclidean')
>>> kdt.query(X, k=2, return_distance=False)
array([[0, 1],
       [1, 0],
       [2, 1],
       [3, 4],
       [4, 3],
       [5, 4]]...)

有关最近邻搜索的可用参数选项的更多信息,请参阅KDTreeand BallTree文档,其中包括查询策略和距离测度等内容。有关可用测度的列表,请参见DistanceMetric 文档。

1.6.2. 最近邻分类

基于最近邻的分类有基于实例的学习(instance-based learning)非概括性学习(non-generalizing learning):它并不试图构建一个一般的内部模型,而只是存储训练数据的实例。分类是根据每个点的最近邻的简单多数票计算得出的:一个查询点被分配给数据类,该类在该点最近的邻居中有最多的代表。

scikit-learn实现两个不同的最近邻分类器: KNeighborsClassifier实现了基于每个查询点的$k$个最近邻进行学习的方法, 其中$k$是用户指定的整数值。 RadiusNeighborsClassifier实现了基于固定半径内的邻居进行学习的方法,其中$r$ 是用户指定的浮点值。

KNeighborsClassifier中的$k$-近邻分类($k$-neighbors classification)是最常用的技术。$k$值的最优选择是与数据高度依赖的:一般来说$k$值越大就越能够消除噪声带来的影响,但是会使得分类边界变得越不清晰。

如果数据是不均匀采样的,则基于半径的近邻分类RadiusNeighborsClassifier可能是更好的选择。用户指定一个固定半径$r$,以便在越稀疏的邻域内的点可以使用越少的最近邻点进行分类。对于高维参数空间,这个方法就会因为维数灾难而变得越来越没有效率。

基本的最近邻分类使用统一的权重:即,分配给查询点的标签值是从其最近邻的简单多数投票中计算出来的。 在某些环境下,最好对邻居进行加权,使得离得越近的点对最终的预测做出越大的贡献。这可以通过关键字 weights 来实现。 该参数的默认值 weights = 'uniform' 给每个邻居分配均匀的权重。而 weights = 'distance' 给每个邻居分配的权重是每个邻居点到查询点的距离的倒数。 除此之外,用户还可以传递自定义的距离函数来计算权重:

分类_1 分类_2

案例:

1.6.3. 最近邻回归

在数据标签是连续变量而不是离散变量的情况下,可以使用基于最近邻的回归算法。分配给查询点的标签是基于其最近邻的标签的平均值的计算而来的。

scikit-learn实现了两个不同的最近邻回归器: KNeighborsRegressor实现了基于每个查询点的$k$个最近邻进行学习的方法, 其中$k$是用户指定的整数值。 RadiusNeighborsRegressor实现了基于固定半径内的邻居进行学习的方法,其中$r$ 是用户指定的浮点值。

基本的最近邻回归使用统一的权重:也就是说,分配给查询点的标签值是从其最近邻的简单多数投票中计算出来的。 在某些环境下,最好对邻居进行加权,使得离得越近的点对最终的预测做出越大的贡献。这可以通过 weights 关键字来实现。 该参数的默认值 weights = 'uniform' 给每个邻居分配均匀的权重。而 weights = 'distance' 给每个邻居分配的权重是每个邻居点到查询点的距离的倒数。 除此之外,用户还可以传递自定义的距离函数来计算权重:

https://scikit-learn.org/stable/_images/sphx_glr_plot_regression_0011.png

多输出最近邻回归的使用案例可以看使用多输出估计器补全人脸。在这个案例中,输入 X 是一些人脸的上半部分, 输出 Y 是那些人脸的下半部分。

https://scikit-learn.org/stable/_images/sphx_glr_plot_multioutput_face_completion_0011.png

案例:

1.6.4. 最近邻算法

1.6.4.1. 暴力(Brute Force)

最近邻的快速计算是机器学习研究的一个活跃领域。最简单的近邻搜索的实现涉及数据集中所有成对点之间距离的暴力计算(Brute Force computation): 对于$D$维空间中的$N$个样本来说, 这个方法的复杂度是$O[DN^2]$。 对于少量的数据样本来说,高效的暴力近邻搜索是非常有竞争力的。 然而,随着样本数$N$的增长,暴力搜索的方法很快变得不切实际了。 在 sklearn.neighbors中,暴力近邻搜索算法通过关键字参数 algorithm = 'brute' 来指定,并通过sklearn.metrics.pairwise中的例程(routines)来进行计算。

1.6.4.2. K-D树(K-D Tree)

为了解决暴力搜索算法的计算效率低下的问题,发明了大量的基于树的数据结构的算法。通常,这些结构试图通过有效地编码样本的聚合距离信息(aggregate distance information)来减少所需要的距离的计算量。基本想法是,如果点$A$离点$B$很远, 但$B$又非常接近点$C$的话,那么我们就可以知道$A$点和$C$点相距很远,而不必显式计算它们的距离。这样,可以将最近邻搜索的计算成本降低到$O[DNlog⁡(N)]$或者更低。这是对于暴力搜索在大样本数$N$中表现的显著改善。

利用这种聚合距离信息(aggregate distance information)的早期方法是K-D树(K-D Tree)数据结构(K维树(K-dimensional tree)的缩写),它将二维 四叉树(Quad-trees) 和三维 八叉树(Oct-trees) 推广到任意数量的维度. K-D树(K-D Tree)是一个二叉树结构, 它沿着数据轴递归地划分参数空间,将其划分为嵌入数据点的嵌套的各向异性区域。 K-D树(K-D Tree)的构造非常快:因为只需沿数据轴执行分区, 无需计算$D$维(D-dimensional)的距离。 一旦构建完成, 查询点的最近邻距离计算复杂度仅为$O[log⁡(N)]$。 虽然K-D树(K-D Tree)的方法对于低维度 ($D$ < 20) 近邻搜索非常快, 当$D$增长到很大时, 效率变低: 这就是所谓的 “维度灾难” 的一种体现。 在 scikit-learn 中, K-D树(K-D Tree)近邻搜索可以使用关键字 algorithm = 'kd_tree' 来指定, 并且使用KDTree来进行计算

参考文献:

1.6.4.3. Ball 树(Ball Tree)

为了解决K-D树(K-D Tree)在更高维度上低效率的问题,发明了Ball 树(Ball Tree) 数据结构。其中 K-D 树沿迪卡尔轴(即坐标轴)分割数据, 而 Ball 树在沿着一系列的嵌套的超球面(nesting hyper-spheres)来分割数据。 通过这种方法构建的树要比K-D树消耗更多的时间, 但是这种数据结构对于高结构化的数据是非常有效的, 即使在高维度上也是一样。

Ball 树将数据递归地划分到由质心$C$和半径$r$定义的节点上,以使得节点内的每个点都位于由质心$C$和半径$r$定义的超球面(hyper-sphere)内。通过使用 三角不等式(triangle inequality) 来减少近邻搜索的候选点数: $$ |x+y| \leq |x| + |y| $$ 测试点和质心之间的单一距离计算足以确定测试点到节点内所有点的距离的下限和上限。 由于Ball 树节点的球形几何(spherical geometry), 它在高维度上的性能超出 K-D树 , 尽管实际的性能高度依赖于训练数据的结构。 在 scikit-learn 中, 基于Ball 树的近邻搜索可以使用关键字 algorithm = 'ball_tree' 来指定, 并且使用sklearn.neighbors.BallTree来计算,或者, 用户可以直接使用BallTree

参考文献:

1.6.4.4. 最近邻算法的选择

为给定数据集的选择一个最佳算法是很复杂的,并且取决于许多因素:

  • 样品数量$N$(即n_samples)和维度$D$(即n_features)。

    • 暴力(Brute force)的查询时间增长复杂度是$O[DN]$
    • Ball 树 的查询时间增长复杂度是$O[D \log(N)]$
    • K-D 树 的查询时间随着$D$变化,所以很难准确的量化。对于一个小的$D$(<=20)查询代价可以近似认为是$O[Dlog⁡(N)]$, 而 K-D 树的查询时间也非常有效率。对于较大的$D$ ,查询代价几乎增加到$O[DN]$, 由于树结构引起的过载(overhead)导致查询比暴力(brute force)还要慢。

对于小型数据集($N$少于30个左右)来说, 由于$log⁡(N)$相当于$N$,所以暴力算法可能比基于树的算法更有效。KDTreeBallTree 通过提供一个参数leaf size来控制样本的数量, 一旦小于这个数量则直接使用暴力搜索进行查询。这样的做法使得这两个算法类对于较小的$N$能够达到接近暴力搜索算法的效率。

  • 数据结构:数据的固有维数(intrinsic dimensionality)稀疏性(sparsity)。数据的固有维数是数据所在的流形(manifold)的维数$d≤D$,其中数据的流形可以是线性或非线性的嵌入到参数空间里的。 数据的稀疏性是指数据填充参数空间的度(这里数据稀疏性的概念区别于稀疏矩阵的稀疏概念,数据矩阵有可能一个0都没有,但是该矩阵的结构可能仍然是稀疏的。)

    • 暴力(Brute force) 的查询时间与数据的结构无关。

    • Ball 树K-D 树 的查询时间可能会受到数据结构的影响。通常情况下,具有越小的固有维数的越稀疏的数据会带来越快的查询时间。 因为 K-D 树 的内部表示是对齐到参数坐标系轴上的,所以它不会在任意结构化的数据上与Ball 树有同样的效率提升。

机器学习中用到的数据集都是往往非常结构化,非常适合基于树的查询。

  • 一个查询点需要的邻居的数量$k$

    • 暴力(Brute force)查询时间在很大程度上不受$k$值的影响。
    • Ball 树K-D 树 的查询时间将会随着$k$的增加而越来越小。 这主要受到两方面的影响: 首先,一个较大的$k$ 值会导致搜索在参数空间中进行较大地搜索;第二,使用$k$>1需要在遍历树时对结果进行内部排队。

随着$k$相较于$N$越来越大, 基于树的查询进行剪枝的能力就会越来越小。在这种情况下,暴力搜索查询会更有效率。

  • 查询点的数量。 Ball 树和K-D 树都需要一个构建过程。 在进行多次的查询中,这个构建过程的成本可以忽略不计。 如果只执行少量的查询, 但构建成本却占总成本的很大一部分的话,暴力方法会比基于树的方法更好。

就当前来说,当你设置参数algorithm = 'auto'时,该算法会在当$k >= N/2$,并且输入数据是稀疏的时,或者effective_metric_没有在包含有'kd_tree''ball_tree'的列表VALID_METRICS中时,该算法会自动选择'brute'(暴力搜索)算法。否则,它会优先选择'kd_tree''ball_tree',其中选择'ball_tree'的条件是, effective_metric_有在包含有'ball_tree'的列表VALID_METRICS中。该选择基于以下假设:查询点的数量至少与训练点的数量相同,并且leaf_size 接近其默认值30

1.6.4.5. leaf_size 的效果

如上所述, 对于小样本来说,暴力搜索是比基于树的搜索更有效的方法。 这一事实在Ball 树和K-D 树中被用于在叶节点内部切换到暴力搜索。该切换的时机可以使用参数 leaf_size 来指定, 这个参数选择有很多的效果:

构建时间

更大的 leaf_size 会导致更快的树构建时间,因为需要创建的节点更少。

查询时间

无论 leaf_size的值的大小,它都可能会导致次于最优的查询成本。当 leaf_size 的值接近于1时,遍历节点所涉及的开销大大减慢了查询时间。 当 leaf_size 接近训练集的大小,本质上查询变得得是暴力的。这些值之间的一个很好的妥协是 leaf_size = 30, 这是该参数的默认值。

内存

随着 leaf_size 的增加,存储树结构所需的内存减少。 对于存储每个$D$维节点的Ball 树来说,这点至关重要。 Ball 树 所需的存储空间近似于 1 / leaf_size 乘以训练集的大小。

leaf_size 在暴力(brute force)查询中是没有用到的。

1.6.5. 最近质心分类器(Nearest Centroid Classifier)

NearestCentroid分类器是一个通过其成员的质心来表示每个类简单的算法。实际上,这使得它类似于 sklearn.KMeans 算法的标签更新阶段。它也没有参数选择,使其成为良好的基准分类器。 然而,在非凸类上,以及当类具有截然不同的方差时,它都会受到影响。所以这个分类器假设所有维度的方差都是相等的。 对于没有做出这个假设的更复杂的方法, 请参阅线性判别分析 (sklearn.discriminant_analysis.LinearDiscriminantAnalysis)和二次判别分析(sklearn.discriminant_analysis.QuadraticDiscriminantAnalysis)。 NearestCentroid的默认用法很简单:

>>> from sklearn.neighbors import NearestCentroid
>>> import numpy as np
>>> X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
>>> y = np.array([1, 1, 1, 2, 2, 2])
>>> clf = NearestCentroid()
>>> clf.fit(X, y)
NearestCentroid()
>>> print(clf.predict([[-0.8, -1]]))
[1]

1.6.5.1. 最近收缩质心(Nearest Shrunken Centroid)

NearestCentroid有一个shrink_threshold参数,它实现了最近收缩质心(Nearest Shrunken Centroid)分类。实际上,每个质心的每个特征的值都除以该特征的类中的方差。然后通过shrink_threshold来减小特征值。最值得注意的是,如果特定特征值超过零,则将其设置为零。实际上,这个方法移除了影响分类器的特征。 这很有用,例如,去除噪声特征。

在以下例子中, 使用一个较小的 shrink 阀值将模型的准确度从 0.81 提高到 0.82。

nearest_centroid_1 nearest_centroid_2

案例:

1.6.6. 最近的邻居转化器(Nearest Neighbors Transformer)

许多scikit学习估计器都依赖于最近邻算法:例如一些分类器和回归器,KNeighborsClassifierKNeighborsRegressor。还有一些聚类方法,例如 DBSCANSpectralClustering。以及一些流形嵌入,例如TSNEIsomap

所有这些估计器都可以在内部计算最近邻,并且它们大多也接受预先计算的近邻稀疏图,可以通过kneighbors_graphradius_neighbors_graph来设置。当设置 mode='connectivity'时,这些函数会根据需要返回一个二进制邻接稀疏图,例如SpectralClustering。而设置mode='distance'时,它们会根据需要返回距离稀疏图,例如DBSCAN。要将这些功能包含在scikit-learn管道(pipeline)中,可以使用相应的类 KNeighborsTransformerRadiusNeighborsTransformer。这种稀疏图API的在很多方面都是有好处的。

首先,可以多次重复使用预先计算图,例如在改变估计器的参数时候。这可以由用户手动完成,也可以使用scikit-learn管道(pipeline)的缓存属性(caching properties)来完成:

>>> from sklearn.manifold import Isomap
>>> from sklearn.neighbors import KNeighborsTransformer
>>> from sklearn.pipeline import make_pipeline
>>> estimator = make_pipeline(
...     KNeighborsTransformer(n_neighbors=5, mode='distance'),
...     Isomap(neighbors_algorithm='precomputed'),
...     memory='/path/to/cache')

其次,预先计算图可以更好地控制最近邻估计,例如通过在参数n_jobs中启用多重处理(multiprocessing),而该参数n_jobs可能并非在所有估计器中都可用。

最后,预计算可以由自定义估计器执行,以做出不同的实现,例如拟合最近邻方法或具有特殊数据类型的实现。这种预计算的稀疏图需要按照radius_neighbors_graph输出格式进行设置 :

  • CSR矩阵(尽管可以接受COO,CSC或LIL)。

  • 仅明确存储每个样本相对于训练数据的最近邻。这应该包括,与查询点的距离为0的点,计算训练数据与其自身之间最近的邻的矩阵对角线。

  • 每行的data应按递增顺序存储距离(可选。未排序的数据将被稳定排序,从而增加计算开销)。

  • 数据中的所有值均应为非负数。

  • 任何行中的indices都不应有重复项(请参阅https://github.com/scipy/scipy/issues/5807)。

  • 如果要传递预先计算的矩阵的算法是使用k最近邻(与半径邻域相对)的话,则必须在每行中至少存储k个邻居(或如下文所述的k + 1个)。

注意:

当(使用KNeighborsTransformer)查询邻居的特定数量时,n_neighbors的定义是不明确的,因为它可以把每个训练点都认为是自己的邻居,也可以认为不是。这两种选择都不是完美的,因为包含它们会导致在训练和测试期间出现不同数量的非本人邻居(non-self neighbors),而排除它们会导致fit(X).transform(X)fit_transform(X)的运行结果出现不同,这与scikit-learn API背道而驰。我们在KNeighborsTransformer定义中,每个训练点都把n_neighbors个点作为自己的邻居。但是,出于与使用其他定义的其他估计器兼容的原因,当设置mode == 'distance'时会再多计算一个邻居。为了最大限度地提高与其他所有估计器的兼容性,一个比较安全的做法是始终在自定义的最近邻估计器中包括一个额外的邻居,因为不必要的邻居将被后面的估计器过滤掉。

案例:

1.6.7. 邻域成分分析(Neighborhood Components Analysis)

邻域成分分析(NCA,NeighborhoodComponentsAnalysis)是一种距离度量学习算法,与标准欧几里得距离相比,旨在提高最近邻分类的准确性。该算法直接最大化在训练集上留一法的k近邻(KNN)(leave-one-out k-nearest neighbors)分数的随机变量。它还可以用来进行数据可视化和快速分类数据的低维线性投影的学习。

nca_illustration_1 nca_illustration_2

在上图中,我们考虑了随机生成的数据集中的一些点。我们看一下随机KNN分类在点3的表现。点3和其他点之间的链接的粗细与它们的距离成正比,可以看作是随机最近邻预测规则分配给该点的相对权重(或概率)。在原始空间中,点3具有来自不同类别的许多随机邻居,因此正确类别的可能性很小。但是,在NCA学习到的投影空间中,唯一权重不可忽略的随机邻居与点3属于同一类别,从而确保了点3将得到很好的分类。有关更多详细信息,请参见数学公式

1.6.7.1. 分类

如果与最近邻分类器(KNeighborsClassifier)结合使用的话,NCA用于分类很有吸引力。它可以处理多分类问题,而不会增加模型大小,并且不会引入需要用户进行微调的其他参数。

NCA分类已被证明在大小和难度各异的数据集中实践的效果很好。与线性判别分析等相关方法相比,NCA不对类分布进行任何假设。最近邻分类可以产生高度不规则的决策边界。

如果要使用此模型进行分类,需要将学习到最佳变换的NeighborhoodComponentsAnalysis与在投影空间中进行分类的KNeighborsClassifier结合起来 。这是使用示例:

>>> from sklearn.neighbors import (NeighborhoodComponentsAnalysis,
... KNeighborsClassifier)
>>> from sklearn.datasets import load_iris
>>> from sklearn.model_selection import train_test_split
>>> from sklearn.pipeline import Pipeline
>>> X, y = load_iris(return_X_y=True)
>>> X_train, X_test, y_train, y_test = train_test_split(X, y,
... stratify=y, test_size=0.7, random_state=42)
>>> nca = NeighborhoodComponentsAnalysis(random_state=42)
>>> knn = KNeighborsClassifier(n_neighbors=3)
>>> nca_pipe = Pipeline([('nca', nca), ('knn', knn)])
>>> nca_pipe.fit(X_train, y_train)
Pipeline(...)
>>> print(nca_pipe.score(X_test, y_test))
0.96190476...

nca_classification_1 nca_classification_2

该图显示了在鸢尾属植物数据集上的最近邻分类和邻域成分分析分类的决策边界,当仅针对两个特征进行训练和评分(scoring)时,就可以可视化。

1.6.7.2. 降维

NCA可用于有监督的降维。把输入数据投影到线性子空间上,该线性子空间由最小化NCA目标的方向所组成。可以使用参数n_components设置所需要降维的维度。例如,下图显示了降维与主数据分析(sklearn.decomposition.PCA),线性判别分析(sklearn.discriminant_analysis.LinearDiscriminantAnalysis)和邻域分量分析(NeighborhoodComponentsAnalysis)在手写数字识别数据集上的比较,该数据集具有大小$n_{samples} = 1797$ 和 $n_{features} = 64$。将数据集分为大小相等的训练集和测试集,然后进行标准化。为了进行评估,对每种方法找到的二维投影点计算3个最近邻分类精度。每个数据样本属于10个类别之一。

nca_dim_reduction_1 nca_dim_reduction_2 nca_dim_reduction_3

案例:

1.6.7.3. 数学公式

NCA的目标是学习到一个形状是 (n_components, n_features)的最佳线性变换矩阵,该矩阵可以最大化每一个样本$i$的概率$p_i$的和。其中$p_i$是样本$i$被正确分类的概率,即: $$ \underset{L}{\arg\max} \sum\limits_{i=0}^{N - 1} p_{i} $$ 其中$N$ = n_samples,根据学习到的嵌入空间中的随机最近邻居规则,被正确分类的样本$i$的概率$p_i$为:

$$ p_{i}=\sum\limits_{j \in C_i}{p_{i j}} $$ 其中$C_i$是与样本$i$同一类中的点, $p_{ij}$是在嵌入空间中计算欧几里得距离之后并且经过softmax函数后的概率:

$$ p_{i j} = \frac{\exp(-||L x_i - L x_j||^2)}{\sum\limits_{k \ne i} {\exp{-(||L x_i - L x_k||^2)}}} , \quad p_{i i} = 0 $$

1.6.7.3.1. 马氏距离(Mahalanobis distance)

NCA可以看作是在学习(平方的)马氏距离:

$$ || L(x_i - x_j)||^2 = (x_i - x_j)^TM(x_i - x_j), $$ 其中$M=L^TL$是大小为(n_features, n_features)的对称正半定矩阵 。

1.6.7.4. 实现

该实现遵循原始论文[1]中的相关说明。对于优化方法,它当前使用scipy的L-BFGS-B并且在每次迭代时进行全梯度计算,以避免设置(tune)学习速率并提供稳定的学习。

有关NeighborhoodComponentsAnalysis.fit更多信息,请参见下面的示例和docstring 。

1.6.7.5. 复杂度

1.6.7.5.1. 训练

NCA存储成对的距离矩阵,并占用n_samples ** 2内存。时间复杂度取决于优化算法完成的迭代次数。但是,可以使用max_iter参数来设置最大迭代次数。对于每次迭代,时间复杂度为O(n_components x n_samples x min(n_samples, n_features))

1.6.7.5.2.转换(Transform)

此处transform操作返回$LX^T$,因此其时间复杂度等同于n_components * n_features * n_samples_test。该操作不会增加空间复杂度。

参考文献:

[1] “Neighbourhood Components Analysis”, J. Goldberger, S. Roweis, G. Hinton, R. Salakhutdinov, Advances in Neural Information Processing Systems, Vol. 17, May 2005, pp. 513-520.

Wikipedia entry on Neighborhood Components Analysis

©2007-2019,scikit-learn开发人员(BSD许可证)。 显示此页面源码