宽度学习(BLS)实战——python复刻MNIST数据集的数据预处理及训练过程[通俗易懂]

宽度学习(BLS)实战——python复刻MNIST数据集的数据预处理及训练过程[通俗易懂]目录1.宽度学习(BroadLearningSystem)2.MNIST数据集3.复刻MNIST数据集的预处理及训练过程1.宽度学习(BroadLearningSystem)对宽度学习的理解可见于这篇博客宽度学习(BroadLearningSystem)_颹蕭蕭的博客-CSDN博客_宽度学习这里不再做详细解释2.MNIST数据集mnist数据集官网(下载地址):MNISThandwrittendigitdatabase,YannLeCun,Cori

大家好,又见面了,我是你们的朋友全栈君。

目录

 

1.宽度学习(Broad Learning System)

2.MNIST数据集

3.复刻MNIST数据集的预处理及训练过程


1.宽度学习(Broad Learning System)

对宽度学习的理解可见于这篇博客宽度学习(Broad Learning System)_颹蕭蕭的博客-CSDN博客_宽度学习

这里不再做详细解释

2.MNIST数据

mnist数据集官网(下载地址):MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges

MNIST数据集有称手写体数据集,其中中训练集一共包含了 60,000 张图像和标签,而测试集一共包含了 10,000 张图像和标签。测试集中前5000个来自最初NIST项目的训练集.,后5000个来自最初NIST项目的测试集。前5000个比后5000个要规整,这是因为前5000个数据来自于美国人口普查局的员工,而后5000个来自于大学生。
MNIST数据集自1998年起,被广泛地应用于机器学习和深度学习领域,用来测试算法的效果,相当于该领域的”hello world!”

3.复刻MNIST数据集的预处理及训练过程

原bls代码下载地址:Broad Learning System

下载后,我先用原代码中带的数据和代码进行训练,运行结果如下:

1.不含增量的bls代码:

watermark,type_d3F5LXplbmhlaQ,shadow_50,text_Q1NETiBA6ZW_5byT5ZCM5a2m,size_20,color_FFFFFF,t_70,g_se,x_16

2.含有增量的bls代码:

watermark,type_d3F5LXplbmhlaQ,shadow_50,text_Q1NETiBA6ZW_5byT5ZCM5a2m,size_20,color_FFFFFF,t_70,g_se,x_16

可以看到bls训练模型的时间非常短并且精确度达到0.93以上

然后我们回过头来看它用的训练集和测试集,它共输入三个csv文件,分别为test.csv,train.csv,sample_submission.csv

其中格式为:

watermark,type_d3F5LXplbmhlaQ,shadow_50,text_Q1NETiBA6ZW_5byT5ZCM5a2m,size_20,color_FFFFFF,t_70,g_se,x_16

 

这就是我们处理完MNIST数据之后需要bls代码中训练的数据,统计得到以下信息

数据集 数据总数
test.csv(测试集) 28000张
train.csv(训练集) 42000张

其中sample_submission.csv是提交样例,它最后会用来保存训练出的模型对测试集打的标签为csv文件。

那么得到这些信息我们就可以开始处理我们的mnist数据集了,在官网下载完数据集后我们得到了四个文件:

watermark,type_d3F5LXplbmhlaQ,shadow_50,text_Q1NETiBA6ZW_5byT5ZCM5a2m,size_20,color_FFFFFF,t_70,g_se,x_16

这个时候如果你是初学者,你就会奇怪明明是图像数据为什么下载完会是这四个东西?

这是因为为了方便使用,官方已经将70000张图片处理之后存入了这四个二进制文件中,因此我们要对这四个文件进行解析才能看到原本的图片。

此处用到struct包进行解析,详情见于Mnist数据集简介_查里王的博客-CSDN博客_mnist数据集

解析代码:

import os
import struct
import numpy as np

# 读取标签数据集
with open('../data/train-labels.idx1-ubyte', 'rb') as lbpath:
    labels_magic, labels_num = struct.unpack('>II', lbpath.read(8))
    labels = np.fromfile(lbpath, dtype=np.uint8)

# 读取图片数据集
with open('../data/train-images.idx3-ubyte', 'rb') as imgpath:
    images_magic, images_num, rows, cols = struct.unpack('>IIII', imgpath.read(16))
    images = np.fromfile(imgpath, dtype=np.uint8).reshape(images_num, rows * cols)

# 打印数据信息
print('labels_magic is {} \n'.format(labels_magic),
      'labels_num is {} \n'.format(labels_num),
      'labels is {} \n'.format(labels))

print('images_magic is {} \n'.format(images_magic),
      'images_num is {} \n'.format(images_num),
      'rows is {} \n'.format(rows),
      'cols is {} \n'.format(cols),
      'images is {} \n'.format(images))

# 测试取出一张图片和对应标签
import matplotlib.pyplot as plt

choose_num = 1  # 指定一个编号,你可以修改这里
label = labels[choose_num]
image = images[choose_num].reshape(28, 28)

plt.imshow(image)
plt.title('the label is : {}'.format(label))
plt.show()

运行结果:

watermark,type_d3F5LXplbmhlaQ,shadow_50,text_Q1NETiBA6ZW_5byT5ZCM5a2m,size_20,color_FFFFFF,t_70,g_se,x_16

 但是这并不是我们要的东西,我们需要的是将二进制文件解析后存入csv文件中用于训练。

在观察了原代码中所用的csv文件的格式以及bls代码中读取数据的方式后,我发现需要再存入之前对数据添加一个index,其中包括”label”和”pixel0~pixel784″,其中pixel是一维数组的元素编码,由于mnist数据集是28*28的图片,所以,转为一维数组后一共有784个元素。

知道这个原理后,编写代码如下:

import csv

def pixel(p_array, outf):
    with open(outf, "w",newline='') as csvfile:
            writer = csv.writer(csvfile)
            # 先写入columns_name
            writer.writerow(p_array)

def convert(imgf, labelf, outf, n):
    f = open(imgf, "rb")
    o = open(outf, "a")
    l = open(labelf, "rb")

    f.read(16)
    l.read(8)
    images = []

    for i in range(n):
        image = [ord(l.read(1))]
        for j in range(28*28):
            image.append(ord(f.read(1)))
        images.append(image)

    for image in images:
        o.write(",".join(str(pix) for pix in image)+"\n")
    f.close()
    o.close()
    l.close()

if __name__ == '__main__':

    p_array = []
    for j in range(0, 785):
        if j == 0 :
            b1 = "label"
            p_array.append(b1)
        else:
            b1 = 'pixel' + str(j - 1)
            p_array.append(b1)
    pixel(p_array,"../data/mnist_train.csv")
    pixel(p_array,"../data/mnist_test.csv")

    convert("../data/train-images.idx3-ubyte", "../data/train-labels.idx1-ubyte",
        "../data/mnist_train2.csv", 42000)
    convert("../data/t10k-images.idx3-ubyte", "../data/t10k-labels.idx1-ubyte",
        "../data/mnist_test2.csv", 28000)

    print("success!")

代码运行结果;

得到经过二进制文件解析以及格式处理后的数据:

watermark,type_d3F5LXplbmhlaQ,shadow_50,text_Q1NETiBA6ZW_5byT5ZCM5a2m,size_20,color_FFFFFF,t_70,g_se,x_16

现在训练集文件格式与源代码格式一样了,但是,既然是复刻那么我们还有一个问题没有解决——数据总数不一样,根据源代码中信息,训练集有42000张,测试集28000张,但是我们的训练集有60000张,测试集有10000张,所以我们需要稍微处理一下我们数量,其实这个很简单,只要将训练集中的数据匀18000张给测试集就可以了,另外测试集中标签一行需要删除,因为测试集好比高考试卷,标签相当于答案,没有人会把高考答案告诉你然后让你考对不对。这个过程可以用python代码实现,只要加入一点点功能,编写功能代码如下:

(记得删除测试集中的标签)

import csv
def test_add(train_imgf,train_labelf,outf):
    f = open(train_imgf, "rb")
    o = open(outf, "a")
    l = open(train_labelf, "rb")
    f.read(16)
    l.read(8)
    images = []

    for i in range(42001, 60001):
        image = [ord(l.read(1))]
        for j in range(28 * 28):
            image.append(ord(f.read(1)))
        images.append(image)

    for image in images:
        o.write(",".join(str(pix) for pix in image) + "\n")
    f.close()
    o.close()
    l.close()


def pixel(p_array, outf):
    with open(outf, "w",newline='') as csvfile:
            writer = csv.writer(csvfile)
            # 先写入columns_name
            writer.writerow(p_array)

def convert(imgf, labelf, outf, n):
    f = open(imgf, "rb")
    o = open(outf, "a")
    l = open(labelf, "rb")

    f.read(16)
    l.read(8)
    images = []

    for i in range(n):
        image = [ord(l.read(1))]
        for j in range(28*28):
            image.append(ord(f.read(1)))
        images.append(image)

    for image in images:
        o.write(",".join(str(pix) for pix in image)+"\n")
    f.close()
    o.close()
    l.close()

if __name__ == '__main__':

    p_array = []
    for j in range(0, 785):
        if j == 0 :
            b1 = "label"
            p_array.append(b1)
        else:
            b1 = 'pixel' + str(j - 1)
            p_array.append(b1)
    pixel(p_array,"../data/mnist_train2.csv")
    pixel(p_array,"../data/mnist_test2.csv")

    convert("../data/train-images.idx3-ubyte", "../data/train-labels.idx1-ubyte",
        "../data/mnist_train2.csv", 42000)
    convert("../data/t10k-images.idx3-ubyte", "../data/t10k-labels.idx1-ubyte",
        "../data/mnist_test2.csv", 10000)
    test_add("../data/train-images.idx3-ubyte", "../data/train-labels.idx1-ubyte", "../data/mnist_test2.csv")

    print("success!")

处理后,与提交案例一起加入bls训练,可以得到:

watermark,type_d3F5LXplbmhlaQ,shadow_50,text_Q1NETiBA6ZW_5byT5ZCM5a2m,size_20,color_FFFFFF,t_70,g_se,x_16

可以看到这与之前原始数据训练的结果几乎相同

 

 

 

 

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/143425.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • Java中static作用及用法详解「建议收藏」

    Java中static作用及用法详解「建议收藏」static是静态修饰符,什么叫静态修饰符呢?大家都知道,在程序中任何变量或者代码都是在编译时由系统自动分配内存来存储的,而所谓静态就是指在编译后所分配的内存会一直存在,直到程序退出内存才会释放这个空间,也就是只要程序在运行,那么这块内存就会一直存在。这样做有什么意义呢?在Java程序里面,所有的东西都是对象,而对象的抽象就是类,对于一个类而言,如果要使用他的成员,那么普通情况下必须先实例化对象后,通过对象的引用才能够访问这些成员,但是用static修饰的成员可以通过类名加“.”进行直接访问。

    2022年7月8日
    25
  • Cocos2d-x-lua游戏两个场景互相切换MainScene01切换到MainScene02「建议收藏」

    Cocos2d-x-lua游戏两个场景互相切换MainScene01切换到MainScene02

    2022年1月21日
    38
  • sqlformat数字格式化_java怎么输出数字

    sqlformat数字格式化_java怎么输出数字前言以前用到要对数字格式的地方,都是直接到网上搜一下。拿过来能用就行。因为平时用的不多。但是最近的项目对这个用的多了。网上拿来的不够用了。自己看了java源码把这方面恶补了。而且最近也好长时间没有写博客了。正好写一篇抛砖引玉吧。正文如果你对java源码比较了解。你会发现java对文字,数字的格式化,是有一个公共的父类的Format。NumberFormat和DecimalFormat都是它…

    2022年10月8日
    3
  • vue webpak版本 查看_vue版本以及webpack版本

    vue webpak版本 查看_vue版本以及webpack版本vue作为大前端的主流框架更新速度也是极快。那么vue的更新会有哪些问题呢?最近在搭建vue框架的时候发现由于vue版本的快速迭代已经与原本般配的webpack产生了隔阂。webpack作为大前端的主流打包工具如果与之不兼容,会有越来越多的麻烦事情。经过反复测试,得出结论一篇vue与webpack最佳拍档组合版本号公布。npminitnpminstallwebpack@3.10.0v…

    2022年6月1日
    432
  • MPP架构概念_体系架构是什么意思

    MPP架构概念_体系架构是什么意思MPP架构概念1.什么是MPPMPP(MassivelyParallelProcessing),即大规模并行处理。什么是并行处理?在数据库集群中,首先每个节点都有独立的磁盘存储系统和内存系统,其次业务数据根据数据库模型和应用特点划分到各个节点上,MPP是将任务并行的分散到多个服务器和节点上,在每个节点上计算完成后,将各自部分的结果汇总在一起得到最终的结果。什么是大规模?每台数据节点通过专用网络或者商业通用网络互相连接,彼此协同计算,作为整体提供数据库服务。整个集群称为非共享数据库集群,非

    2025年7月12日
    7
  • Java中的map集合顺序如何与添加顺序一样

    Java中的map集合顺序如何与添加顺序一样一般使用map用的最多的就是hashmap,但是hashmap里面的元素是不按添加顺序的,那么除了使用hashmap外,还有什么map接口的实现类可以用呢?这里有2个,treeMap和linkedHashMap,但是,要达到我们的要求:按添加顺序保存元素的,就只有LinkedHashMap。下面看运行的代码。packagecom.lxk.collectionTest;impor…

    2022年5月30日
    109

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号