C++版OpenCV使用神经网络ANN进行mnist手写数字识别[通俗易懂]

C++版OpenCV使用神经网络ANN进行mnist手写数字识别[通俗易懂]说起神经网络,很多人以为只有Keras或者tensorflow才支持,其实OpenCV也支持神经网络的,下面就使用OpenCV的神经网络进行手写数字识别,训练10次的准确率就高达96%。环境准备:vs2015OpenCV4.5.0以下为ANN神经网络的训练代码:#include<iostream>#include<opencv.hpp>#include<string>#include<fstream>usingnamespacestd

大家好,又见面了,我是你们的朋友全栈君。

说起神经网络,很多人以为只有Keras或者tensorflow才支持,其实OpenCV也支持神经网络的,下面就使用OpenCV的神经网络进行手写数字识别,训练10次的准确率就高达96%。
环境准备:
vs2015
OpenCV4.5.0
以下为ANN神经网络的训练代码:

#include<iostream>
#include<opencv.hpp>
#include <string>
#include <fstream>
using namespace std;
using namespace cv;
using namespace cv::ml;


//小端存储转换
int reverseInt(int i);
//读取image数据集信息
Mat read_mnist_image(const string fileName);
//读取label数据集信息
Mat read_mnist_label(const string fileName);
//将标签数据改为one-hot型
Mat one_hot(Mat label, int classes_num);

string train_images_path = "G:/vs2015_opencv_ml/mnist/train-images.idx3-ubyte";
string train_labels_path = "G:/vs2015_opencv_ml/mnist/train-labels.idx1-ubyte";
string test_images_path = "G:/vs2015_opencv_ml/mnist/t10k-images.idx3-ubyte";
string test_labels_path = "G:/vs2015_opencv_ml/mnist/t10k-labels.idx1-ubyte";

int main()
{ 
   
	/* ---------第一部分:训练数据准备----------- */
	//读取训练标签数据 (60000,1) 类型为int32
	Mat train_labels = read_mnist_label(train_labels_path);
	//ann神经网络的标签数据需要转为one-hot型
	train_labels = one_hot(train_labels, 10);

	//读取训练图像数据 (60000,784) 类型为float32 数据未归一化
	Mat train_images = read_mnist_image(train_images_path);
	//将图像数据归一化
	train_images = train_images / 255.0;

	//读取测试数据标签(10000,1) 类型为int32 测试标签不用转为one-hot型
	Mat test_labels = read_mnist_label(test_labels_path);

	//读取测试数据图像 (10000,784) 类型为float32 数据未归一化
	Mat test_images = read_mnist_image(test_images_path);
	//归一化
	test_images = test_images / 255.0;

	/* ---------第二部分:构建ann训练模型并进行训练----------- */
	cv::Ptr<cv::ml::ANN_MLP> ann = cv::ml::ANN_MLP::create();
	//定义模型的层次结构 输入层为784 隐藏层为64 输出层为10
	Mat layerSizes = (Mat_<int>(1, 3) << 784, 64, 10);
	ann->setLayerSizes(layerSizes);
	//设置参数更新为误差反向传播法
	ann->setTrainMethod(ANN_MLP::BACKPROP, 0.001, 0.1);
	//设置激活函数为sigmoid
	ann->setActivationFunction(ANN_MLP::SIGMOID_SYM, 1.0, 1.0);
	//设置跌打条件 最大训练次数为100
	ann->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER | TermCriteria::EPS, 10, 0.0001));

	//开始训练
	cv::Ptr<cv::ml::TrainData> train_data = cv::ml::TrainData::create(train_images, cv::ml::ROW_SAMPLE,train_labels);
	cout << "开始进行训练..." << endl;
	ann->train(train_data);
	cout << "训练完成" << endl;

	/* ---------第三部分:在测试数据集上预测计算准确率----------- */
	Mat pre_out;
	//返回值为第一个图像的预测值 pre_out为整个batch的预测值集合
	cout << "开始进行预测..." << endl;
	float ret = ann->predict(test_images, pre_out);
	cout << "预测完成" << endl;

	//计算准确率
	int equal_nums = 0;
	for (int i = 0; i < pre_out.rows; i++)
	{ 
   
		//获取每一个结果的最大值所在下标
		Mat temp = pre_out.rowRange(i, i + 1);
		double maxVal = 0;
		cv::Point maxPoint;
		cv::minMaxLoc(temp,NULL, &maxVal,NULL, &maxPoint);
		int max_index = maxPoint.x;
		int test_index = test_labels.at<int32_t>(i, 0);
		if (max_index == test_index)
		{ 
   
			equal_nums++;
		}
	}
	float acc = float(equal_nums) / float(pre_out.rows);
	cout << "测试数据集上的准确率为:" << acc * 100 << "%" << endl;
	//保存模型
	ann->save("mnist_ann.xml");


	getchar();
	return 0;
}


;

int reverseInt(int i) { 
   
	unsigned char c1, c2, c3, c4;

	c1 = i & 255;
	c2 = (i >> 8) & 255;
	c3 = (i >> 16) & 255;
	c4 = (i >> 24) & 255;

	return ((int)c1 << 24) + ((int)c2 << 16) + ((int)c3 << 8) + c4;
}

Mat read_mnist_image(const string fileName) { 
   
	int magic_number = 0;
	int number_of_images = 0;
	int n_rows = 0;
	int n_cols = 0;

	Mat DataMat;

	ifstream file(fileName, ios::binary);
	if (file.is_open())
	{ 
   
		cout << "成功打开图像集 ..." << endl;

		file.read((char*)&magic_number, sizeof(magic_number));//幻数(文件格式)
		file.read((char*)&number_of_images, sizeof(number_of_images));//图像总数
		file.read((char*)&n_rows, sizeof(n_rows));//每个图像的行数
		file.read((char*)&n_cols, sizeof(n_cols));//每个图像的列数

		magic_number = reverseInt(magic_number);
		number_of_images = reverseInt(number_of_images);
		n_rows = reverseInt(n_rows);
		n_cols = reverseInt(n_cols);
		cout << "幻数(文件格式):" << magic_number
			<< " 图像总数:" << number_of_images
			<< " 每个图像的行数:" << n_rows
			<< " 每个图像的列数:" << n_cols << endl;

		cout << "开始读取Image数据......" << endl;

		DataMat = Mat::zeros(number_of_images, n_rows * n_cols, CV_32FC1);
		for (int i = 0; i < number_of_images; i++) { 
   
			for (int j = 0; j < n_rows * n_cols; j++) { 
   
				unsigned char temp = 0;
				file.read((char*)&temp, sizeof(temp));
				//可以在下面这一步将每个像素值归一化
				float pixel_value = float(temp);
				//按照行将像素值一个个写入Mat中
				DataMat.at<float>(i, j) = pixel_value;
			}
		}

		cout << "读取Image数据完毕......" << endl;

	}
	file.close();
	return DataMat;
}

Mat read_mnist_label(const string fileName) { 
   
	int magic_number;
	int number_of_items;

	Mat LabelMat;

	ifstream file(fileName, ios::binary);
	if (file.is_open())
	{ 
   
		cout << "成功打开标签集 ... " << endl;

		file.read((char*)&magic_number, sizeof(magic_number));
		file.read((char*)&number_of_items, sizeof(number_of_items));
		magic_number = reverseInt(magic_number);
		number_of_items = reverseInt(number_of_items);

		cout << "幻数(文件格式):" << magic_number << " ;标签总数:" << number_of_items << endl;

		cout << "开始读取Label数据......" << endl;
		//CV_32SC1代表32位有符号整型 通道数为1
		LabelMat = Mat::zeros(number_of_items, 1, CV_32SC1);
		for (int i = 0; i < number_of_items; i++) { 
   
			unsigned char temp = 0;
			file.read((char*)&temp, sizeof(temp));
			LabelMat.at<unsigned int>(i, 0) = (unsigned int)temp;
		}
		cout << "读取Label数据完毕......" << endl;

	}
	file.close();
	return LabelMat;
}


//将标签数据改为one-hot型
Mat one_hot(Mat label, int classes_num)
{ 
   
	//[2]->[0 1 0 0 0 0 0 0 0 0]
	int rows = label.rows;
	Mat one_hot = Mat::zeros(rows, classes_num, CV_32FC1);
	for (int i = 0; i < label.rows; i++)
	{ 
   
		int index = label.at<int32_t>(i, 0);
		one_hot.at<float>(i, index) = 1.0;
	}
	return one_hot;
}

执行代码,训练结果如下:

成功打开标签集 ...
幻数(文件格式):2049  ;标签总数:60000
开始读取Label数据......
读取Label数据完毕......
成功打开图像集 ...
幻数(文件格式):2051 图像总数:60000 每个图像的行数:28 每个图像的列数:28
开始读取Image数据......
读取Image数据完毕......
成功打开标签集 ...
幻数(文件格式):2049  ;标签总数:10000
开始读取Label数据......
读取Label数据完毕......
成功打开图像集 ...
幻数(文件格式):2051 图像总数:10000 每个图像的行数:28 每个图像的列数:28
开始读取Image数据......
读取Image数据完毕......
开始进行训练...
训练完成
开始进行预测...
预测完成
测试数据集上的准确率为:96.26%

从上可知,使用ANN神经网络仅仅训练10次,就可以达到96.24%的识别率,增大训练次数,这个识别率还会提高,而且ann的模型文件非常小,才一兆多一点,由此可知,ANN模型非常适合端上部署。
在这里插入图片描述
使用ann的模型文件识别OpenCV加载的手写数字图片,代码如下:

#include<iostream>
#include<opencv.hpp>
using namespace std;
using namespace cv;

int main()
{ 
   
	//读取一张手写数字图片(28,28)
	Mat image = cv::imread("shuzi1.jpg", 0);
	Mat img_show = image.clone();
	//更换数据类型有uchar->float32
	image.convertTo(image, CV_32F);
	//归一化
	image = image / 255.0;
	//(1,784)
	image = image.reshape(1, 1);
	
	//加载ann模型
	cv::Ptr<cv::ml::ANN_MLP> ann= cv::ml::StatModel::load<cv::ml::ANN_MLP>("mnist_ann.xml");
	//预测图片
	Mat pre_out;
	float ret = ann->predict(image,pre_out);
	double maxVal = 0;
	cv::Point maxPoint;
	cv::minMaxLoc(pre_out, NULL, &maxVal, NULL, &maxPoint);
	int max_index = maxPoint.x;
	cout << "图像上的数字为:" << max_index << " 置信度为:" << maxVal << endl;

	cv::imshow("img", img_show);
	cv::waitKey(0);
	getchar();
	return 0;
}

执行以上代码,结果如下:
在这里插入图片描述
由此可见,使用该ANN模型能正确识别手写数字,并且ANN模型由于保存的是权重参数,因此模型文件极小,非常适合在端上进行部署。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/147908.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • 删除链表倒数第n个节点_求链表的倒数第m个元素

    删除链表倒数第n个节点_求链表的倒数第m个元素原题链接给你一个链表,删除链表的倒数第 n 个结点,并且返回链表的头结点。进阶:你能尝试使用一趟扫描实现吗?示例 1:输入:head = [1,2,3,4,5], n = 2输出:[1,2,3,5]示例 2:输入:head = [1], n = 1输出:[]示例 3:输入:head = [1,2], n = 1输出:[1]提示:链表中结点的数目为 sz1 <= sz <= 300 <= Node.val <= 1001 <= n <= s

    2022年8月8日
    9
  • Prometheus TSDB存储原理

    Prometheus TSDB存储原理Python 微信订餐小程序课程视频 https blog csdn net m0 article details Python 实战量化交易理财系统 https blog csdn net m0 article details Prometheus 包含一个存储在本地磁盘的时间序列数据库 同时也支持与远程存储系统集成 比如 grafanacloud 提供的免费云存储 API 只需将 remote write 接口信息填写在 Prome

    2025年7月16日
    5
  • oracle删除表空间语句「建议收藏」

    oracle删除表空间语句「建议收藏」–删除空的表空间,但是不包含物理文件droptablespacetablespace_name;–删除非空表空间,但是不包含物理文件droptablespacetablespace_nameincludingcontents;–删除空表空间,包含物理文件drop

    2025年7月21日
    2
  • setfacl 权限导出_linux学习-setfacl设置特定目录用户权限

    setfacl 权限导出_linux学习-setfacl设置特定目录用户权限需求:设置用户test,test1对特定的目录有读写执行权限,后加的文件也是这个权限。-R表示递归-m表示设置文件acl规则setfacl-R-md:u:test:rwx/data2/testsetfacl-R-md:u:test1:rwx/data2/test–删除ACL规则使用-bsetfacl-R-b/data2/test上面的d:u:详见如下,而perms对应的是…

    2022年6月22日
    38
  • Wallpaper Engine 占用GPU过高解决办法「建议收藏」

    Wallpaper Engine 占用GPU过高解决办法「建议收藏」看到本文的时候,首先你要有一个大致认识:Wallpaper中的壁纸大致分为两种:一种是实时计算渲染的,一种是视频播放渲染的。当你明白这一点的时候就不难解释为什么有的壁纸不大,但是却给人一种挖矿的感觉,有的壁纸很大却完美运行。。。。目录吐槽:解决办法:总结吐槽:今天找到了一个很好看(屌丝)的壁纸,结果应用起来,却发现电脑卡顿严重(见下图),虽说我的显卡1650不是很好,可也不至于带不动个20多MB的壁纸吧???于是乎……..我发现是我想简单了,他这个壁纸是..

    2022年6月17日
    1.0K
  • C++中的string类用法简介

    C++中的string类用法简介本文主要介绍C++中的string类的常见用法。1.概述string是C++标准库的一个重要的部分,主要用于字符串处理。可以使用输入输出流方式直接进行string操作,也可以通过文件等手段进行string操作。同时,C++的算法库对string类也有着很好的支持,并且string类还和c语言的字符串之间有着良好的接口。2.常见用法2.1string转换为char*方法一:…

    2022年6月13日
    29

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号