手写识别系统
说明:
将数据集文件 ‘digits.zip’ 解压至当前文件夹
定义将图像转换为向量函数
import numpy
as np
import operator
from os
import listdir
def img2vector(filename):
    """Read a 32x32 text image of digit characters into a 1x1024 row vector.

    The first 32 characters of each of the first 32 lines are converted with
    ``int()``; character ``j`` of line ``i`` is stored at column
    ``32 * i + j`` of the result.

    Args:
        filename: Path of the text file to read.

    Returns:
        A NumPy array of shape ``(1, 1024)`` created with ``np.zeros`` and
        filled with the parsed values.

    Raises:
        OSError: If the file cannot be opened.
        IndexError: If the file has fewer than 32 lines, or a line ends
            (without a newline) before 32 characters were read.
        ValueError: If one of the first 32 characters of a line is not
            accepted by ``int()`` (this includes the newline of a short line).
    """
    returnVect = np.zeros((1, 1024))
    # The context manager closes the file even when parsing fails part-way;
    # the previous version left the handle open.
    with open(filename) as fr:
        for i in range(32):
            lineStr = fr.readline()
            for j in range(32):
                returnVect[0, 32 * i + j] = int(lineStr[j])
    return returnVect
# Quick check: vectorize one training sample and show entries 0..30 of its row.
returnVect = img2vector("./digits/trainingDigits/0_0.txt")
returnVect[0, 0:31]
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
定义 k 近邻算法
def classify0(inX, dataSet, labels, k):
    """Classify ``inX`` by majority vote among its ``k`` nearest rows.

    Distances are Euclidean: the query is repeated once per row of
    ``dataSet``, subtracted, squared, summed along axis 1 and square-rooted.

    Args:
        inX: Query vector to classify.
        dataSet: Array whose rows are the known samples (``.shape[0]`` rows).
        labels: Sequence giving the label of each row of ``dataSet``.
        k: Number of nearest rows that take part in the vote.

    Returns:
        The label with the most votes. When several labels share the top
        count, the one met first while walking from nearest to farthest wins
        (the sort is stable and the tally keeps insertion order).
    """
    row_count = dataSet.shape[0]

    # Repeat the query once per sample so the subtraction is row against row.
    deltas = np.tile(inX, (row_count, 1)) - dataSet
    squared_sums = (deltas ** 2).sum(axis=1)
    euclidean = squared_sums ** 0.5
    nearest_first = euclidean.argsort()

    # Tally the labels of the k closest samples, nearest one first.
    votes = {}
    for rank in range(k):
        label = labels[nearest_first[rank]]
        if label in votes:
            votes[label] += 1
        else:
            votes[label] = 1

    ranking = sorted(votes.items(), key=operator.itemgetter(1), reverse=True)
    winner, _count = ranking[0]
    return winner
定义手写数字识别系统函数
def _label_from_filename(fileNameStr):
    """Return the integer before the first ``_`` in a name like ``'3_17.txt'``."""
    fileStr = fileNameStr.split('.')[0]
    return int(fileStr.split('_')[0])


def handwritingClassTest():
    """Run the 1-nearest-neighbour digit classifier over the test directory.

    Every file in ``./digits/trainingDigits`` is loaded with ``img2vector``
    into one row of a ``(m, 1024)`` matrix; its label is the integer prefix of
    the file name. Each file in ``./digits/testDigits`` is then classified
    with ``classify0`` (``k = 1``) and compared with its own label.

    Prints one line per test file, then the number of errors and the error
    rate. Returns ``None``.

    An empty test directory reports an error rate of 0.0 instead of failing
    with ``ZeroDivisionError``.
    """
    trainingFileList = listdir('./digits/trainingDigits')
    hwLabels = [_label_from_filename(name) for name in trainingFileList]
    trainingMat = np.zeros((len(trainingFileList), 1024))
    for i, fileNameStr in enumerate(trainingFileList):
        trainingMat[i, :] = img2vector('./digits/trainingDigits/%s' % fileNameStr)

    testFileList = listdir('./digits/testDigits')
    mTest = len(testFileList)
    errorCount = 0
    for fileNameStr in testFileList:
        classNumStr = _label_from_filename(fileNameStr)
        vectorUnderTest = img2vector('./digits/testDigits/%s' % fileNameStr)
        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 1)
        print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr))
        if classifierResult != classNumStr:
            errorCount += 1

    print("\nthe total number of errors is: %d" % errorCount)
    # Guard the division: with no test files there is nothing to average over.
    errorRate = errorCount / float(mTest) if mTest else 0.0
    print("\nthe total error rate is: %f" % errorRate)
# Run the full train/test evaluation and print the results.
handwritingClassTest()
参考文档: