使用KNN算法识别手写数字——简记

    xiaoxiao2022-07-06  202

    代码来源:

    该代码来自:《机器学习实战》第二章K-近邻算法P31页

    digits 文件下数据格式分析:

    训练数据的手写体数字个数为 1934 测试的手写体数字个数为 946 该目录下的文件按照规则命名,如文件9_45.txt的分类是9,它是数字9的第45个实例。

    代码分析:

    该算法的执行效率不高,因为该算法需要为每个测试向量做2000次距离计算,每个距离计算包括1024个维度浮点运算,总共要做900次。 K-近邻算法是基于实例的学习,使用算法时我们必须有接近实际数据的训练样本数 据。K-近邻算法必须保存全部数据集,如果训练数据集的很大,必须使用大量的存储空间。此外, 由于必须对数据集中的每个数据计算距离值,实际使用时可能非常耗时。 K-近邻算法的另一个缺陷是它无法给出任何数据的基础结构信息,因此我们也无法知晓平均 实例样本和典型实例样本具有什么特征.

    进行优化的方法:

    K决策树就是KNN算法的优化版,可以节省大量计算开销

    KNN的手写体数字识别代码如下:

    ''' kNN: k Nearest Neighbors Input: inX: vector to compare to existing dataset (1xN) dataSet: size m data set of known vectors (NxM) labels: data set labels (1xM vector) k: number of neighbors to use for comparison (should be an odd number) Output: the digit label ''' from numpy import * import operator from os import listdir # KNN分类方法 def classify0(inX, dataSet, labels, k): # 距离计算——使用欧氏距离,计算两个向量点的距离 dataSetSize = dataSet.shape[0] diffMat = tile(inX, (dataSetSize,1)) - dataSet sqDiffMat = diffMat**2 sqDistances = sqDiffMat.sum(axis=1) distances = sqDistances**0.5 # 把距离从小到大排序 sortedDistIndicies = distances.argsort() classCount={} for i in range(k): # 统计距离最 近的 k 个 数据的 标签分类出现次数 voteIlabel = labels[sortedDistIndicies[i]] classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1 # 对 标签分类出现次数 进行 从大到小 排序 sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True) return sortedClassCount[0][0] # 这里把 32 * 32 的 二进制图像矩阵 转换为 1 * 1024 的向量 def img2vector(filename): returnVect = zeros((1,1024)) fr = open(filename) for i in range(32): lineStr = fr.readline() for j in range(32): returnVect[0,32*i+j] = int(lineStr[j]) return returnVect def handwritingClassTest(): hwLabels = [] trainingFileList = listdir('digits/trainingDigits') #load the training set m = len(trainingFileList) trainingMat = zeros((m,1024)) for i in range(m): fileNameStr = trainingFileList[i] # 获取文件名 fileStr = fileNameStr.split('.')[0] #take off .txt # 获取标签数值 classNumStr = int(fileStr.split('_')[0]) # 待识别图片数字的 标签数值 保存在 hwLabels 中 hwLabels.append(classNumStr) trainingMat[i,:] = img2vector('digits/trainingDigits/%s' % fileNameStr) testFileList = listdir('digits/testDigits') #iterate through the test set errorCount = 0.0 mTest = len(testFileList) for i in range(mTest): fileNameStr = testFileList[i] fileStr = fileNameStr.split('.')[0] #take off .txt classNumStr = int(fileStr.split('_')[0]) vectorUnderTest = img2vector('digits/testDigits/%s' % fileNameStr) # 调用KNN分类方法 classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3) print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr)) if(classifierResult != classNumStr): errorCount += 1.0 print("\n the total number of errors is: %d" % errorCount) print("\n the total error rate is: %f" % (errorCount/float(mTest))) print("测试的手写体数字个数为 %d " % mTest) # 调用方法 handwritingClassTest()

    手写体数据和KNN实现代码csdn下载链接

    最新回复(0)