k近邻

k近邻（k-Nearest Neighbors）采用向量空间模型来分类，是一种常用的监督学习方法。它的工作原理为：给定测试样本，基于某种距离度量找出训练集中与其最靠近的k个训练样本，然后基于这k个“邻居”的信息来进行预测。通常，在分类任务中可使用“投票法”，即选择这k个样本中出现最多的类别标记作为预测结果；在回归任务中可使用“平均法”，即将这k个样本的实值输出标记的平均值作为预测结果；还可基于距离远近进行加权平均或加权投票，距离越近的样本权重越大。

k近邻没有显式的训练过程，是“懒惰学习”的代表。此类学习技术在训练阶段仅仅是把样本保存起来，训练时间开销为零，待收到测试样本后再进行处理。

示例

%matplotlib inline
import time
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.datasets.samples_generator import make_circles

N=210
K=2
MAX_ITERS = 1000
cut=int(N*0.7)

# 生成训练和测试数据集
data, features = make_circles(n_samples=N, shuffle=True, noise= 0.12, factor=0.4)
tr_data, tr_features= data[:cut], features[:cut]
te_data,te_features=data[cut:], features[cut:]

fig, ax = plt.subplots()
ax.scatter(tr_data.transpose()[0], tr_data.transpose()[1], marker = 'o', s = 100, c = tr_features, cmap=plt.cm.coolwarm )
ax.set_title('Train data')
plt.show()

start = time.time()

points=tf.Variable(data)
cluster_assignments = tf.Variable(tf.zeros([N], dtype=tf.int64))

sess = tf.Session()
sess.run(tf.initialize_all_variables())

te_learned_features=[]
for i, j in zip(te_data, te_features):
distances = tf.reduce_sum(tf.square(tf.sub(i , tr_data)),reduction_indices=1)
neighbor = tf.arg_min(distances,0)

#print tr_features[sess.run(neighbor)]
te_learned_features.append(tr_features[sess.run(neighbor)])

accuracy = tf.reduce_mean(tf.cast(tf.equal(te_learned_features, te_features), "float"))

fig, ax = plt.subplots()
ax.scatter(te_data.transpose()[0], te_data.transpose()[1], marker = 'o', s = 100, c = te_learned_features, cmap=plt.cm.coolwarm )
ax.set_title('Test result')
plt.show()

end = time.time()
print ("Found in %.2f seconds" % (end-start))
print "Cluster assignments:", test
print "Accuracy:", sess.run(accuracy)


Found in 6.73 seconds
Cluster assignments: [0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0]
Accuracy: 1.0

© Pengfei Ni all right reserved，powered by GitbookUpdated at 2018-03-04 20:56:37