本博客属于原创,若转载请标明转载出处:https://blog.csdn.net/qq_44091004/article/details/90573495 小编在开发的过程中就省去了ANN的原理部分,所以小编认为大家是有一定基础的。本案例是ANN-预测手写数字的案例。 本文需要的开发环境是新建一个工程(new project),这个工程里已经配置了OpenCV开发环境。 第一步:首先就是数据集材料的从哪里来,小编这里有一段程序是可以直接运行的App,作用是通过一个面板手绘数字,然后通过缩小的功能将图片都缩小为12*12的尺寸,然后存在自己指定的盘中,比如小编这里将手写图片数据都存在E://opencv3.1//samples//ocr中。从下图中可以看到,我的源数据集从何而来。
第二步:建立手写数字的资料库。 该数据库里包含五十个训练数据,每个数字包含五个手写体,还有一个测试数据sample7。下面的代码中有一些需要注意的地方,首先是将图片这个二维矩阵数据转换为行样本存储在Mat中,ANN中每个样本的标签用一个1行10列的矩阵来存储。
package com; import org.opencv.core.Core; import org.opencv.core.CvType; import org.opencv.core.Mat; import org.opencv.imgcodecs.Imgcodecs; public class OcrDatabase { static {System.loadLibrary(Core.NATIVE_LIBRARY_NAME);}; Mat trainingLabelsFloatMat = new Mat(50,10,CvType.CV_32FC1);//创建训练数据集的标签存储矩阵 Mat trianingDataMat = new Mat(50,144,CvType.CV_32FC1) ;//创建训练数据集的存储矩阵 Mat testingDataMat = new Mat(10,144,CvType.CV_32FC1);//创建测试数据集的存储矩阵 float[] testingLabels = {0,1,2,3,4,5,6,7,8,9}; float[][] trainingLabelsFloat = { {1,0,0,0,0,0,0,0,0,0}, {1,0,0,0,0,0,0,0,0,0}, {1,0,0,0,0,0,0,0,0,0}, {1,0,0,0,0,0,0,0,0,0}, {1,0,0,0,0,0,0,0,0,0}, {0,1,0,0,0,0,0,0,0,0}, {0,1,0,0,0,0,0,0,0,0}, {0,1,0,0,0,0,0,0,0,0}, {0,1,0,0,0,0,0,0,0,0}, {0,1,0,0,0,0,0,0,0,0}, {0,0,1,0,0,0,0,0,0,0}, {0,0,1,0,0,0,0,0,0,0}, {0,0,1,0,0,0,0,0,0,0}, {0,0,1,0,0,0,0,0,0,0}, {0,0,1,0,0,0,0,0,0,0}, {0,0,0,1,0,0,0,0,0,0}, {0,0,0,1,0,0,0,0,0,0}, {0,0,0,1,0,0,0,0,0,0}, {0,0,0,1,0,0,0,0,0,0}, {0,0,0,1,0,0,0,0,0,0}, {0,0,0,0,1,0,0,0,0,0}, {0,0,0,0,1,0,0,0,0,0}, {0,0,0,0,1,0,0,0,0,0}, {0,0,0,0,1,0,0,0,0,0}, {0,0,0,0,1,0,0,0,0,0}, {0,0,0,0,0,1,0,0,0,0}, {0,0,0,0,0,1,0,0,0,0}, {0,0,0,0,0,1,0,0,0,0}, {0,0,0,0,0,1,0,0,0,0}, {0,0,0,0,0,1,0,0,0,0}, {0,0,0,0,0,0,1,0,0,0}, {0,0,0,0,0,0,1,0,0,0}, {0,0,0,0,0,0,1,0,0,0}, {0,0,0,0,0,0,1,0,0,0}, {0,0,0,0,0,0,1,0,0,0}, {0,0,0,0,0,0,0,1,0,0}, {0,0,0,0,0,0,0,1,0,0}, {0,0,0,0,0,0,0,1,0,0}, {0,0,0,0,0,0,0,1,0,0}, {0,0,0,0,0,0,0,1,0,0}, {0,0,0,0,0,0,0,0,1,0}, {0,0,0,0,0,0,0,0,1,0}, {0,0,0,0,0,0,0,0,1,0}, {0,0,0,0,0,0,0,0,1,0}, {0,0,0,0,0,0,0,0,1,0}, {0,0,0,0,0,0,0,0,0,1}, {0,0,0,0,0,0,0,0,0,1}, {0,0,0,0,0,0,0,0,0,1}, {0,0,0,0,0,0,0,0,0,1}, {0,0,0,0,0,0,0,0,0,1} }; Mat sample7=new Mat(1,144,CvType.CV_32FC1); public OcrDatabase() { Mat source; //assign training Mat for(int i = 0;i < 50;i++) { if(i < 10) { source = Imgcodecs.imread("E://opencv3.1//samples//ocr//0"+i+".jpg",Imgcodecs.CV_LOAD_IMAGE_GRAYSCALE); } else{ source = Imgcodecs.imread("E://opencv3.1//samples//ocr//"+i+".jpg",Imgcodecs.CV_LOAD_IMAGE_GRAYSCALE); } Mat temp = source.reshape(1, 144); for(int j = 0;j < 144;j++) { double[] data = new double[1]; data = temp.get(j, 0); trianingDataMat.put(i, j, data); } trainingLabelsFloatMat.put(i, 0, trainingLabelsFloat[i]); } Mat sample = Imgcodecs.imread("E://opencv3.1//samples//ocr//number71.jpg",Imgcodecs.CV_LOAD_IMAGE_GRAYSCALE); Mat tempSample7=sample.reshape(1,144); for(int j=0;j<144;j++){ double[] data=new double[1]; data=tempSample7.get(j, 0); sample7.put(0, j, data); } } public Mat getTrainingDataMat() { return trianingDataMat; } public void setTrainingDataMat(Mat trainingDataMat) { this.trianingDataMat = trainingDataMat; } public Mat getTrainingLabelsFloatMat() { return trainingLabelsFloatMat; } public void setTrainingLabelsFloatMat(Mat trainingLabelsFloatMat) { this.trainingLabelsFloatMat = trainingLabelsFloatMat; } }第三步:测试类
import org.opencv.core.Core; import org.opencv.ml.ANN_MLP; import org.opencv.ml.Ml; import org.opencv.core.CvType; import org.opencv.core.Mat; import org.opencv.core.TermCriteria; public class KNN { static{ System.loadLibrary(Core.NATIVE_LIBRARY_NAME); } public static void main(String[] args) { OcrDatabase ocr=new OcrDatabase();//创建数据集对象,相当于做试验前准备好的实验数据,封装了起来,以便调用 ANN_MLP ann=ANN_MLP.create();//创建ANN //设置网络的模型,首先创建网络模型对象,然后创建数组,将数组中的数据存储到Mat对象中,分别包含输入层,隐含层和输出层,其中输入 //层为144个特征,隐含层包括两层,分别有20和10个神经元,输出层为10个数据 Mat layerSize=new Mat(4,1,CvType.CV_32SC1); int[] layerSizeAry={144, 20,10, 10}; layerSize.put(0,0,layerSizeAry[0]); layerSize.put(1,0,layerSizeAry[1]); layerSize.put(2,0,layerSizeAry[2]); layerSize.put(3,0,layerSizeAry[3]); ann.setLayerSizes(layerSize); ann.setTrainMethod(ann.BACKPROP);//设置训练方法:误差反向传播 TermCriteria criteria=new TermCriteria(TermCriteria.MAX_ITER|TermCriteria.EPS, 300, 0.001);//创建对象,设置标准属性包括最大迭代次数和误差率 ann.setTermCriteria(criteria); ann.setActivationFunction(ann.SIGMOID_SYM);//设置激励函数为SIGMOLD类型 boolean r=ann.train(ocr.getTrainingDataMat(), Ml.ROW_SAMPLE, ocr.getTrainingLabelsFloatMat());//训练函数,样本类型为行样本,然后获取训练数据集和标签进行训练 System.out.println("是否有训练成功="+r); //测试sample7 float result7= ann.predict(ocr.sample7); System.out.println("预测7结果="+result7); } }运行结果如下图所示: 最后,作者参考的书籍: 《opencv3 使用java开发手册》