生成TFrecord

生成TFrecord需求 将图片文件保存成 Tfrecord 的格式 解决方法 基于 tensorflow cv2 numpy 等库完成该功能 注 改编自网上代码 1 准备要训练的手写识别的图片文件 并按照目录结构存放 见下图示意 2 生成训练图片和标签对应的文本文件 见下图示意 3 编写图片生成 TFrecord 代码 代码见下 importnumpya

需求:将图片文件保存成Tfrecord的格式.

解决方法:基于tensorflow、cv2、numpy等库完成该功能.

注:改编自网上代码 

1)  准备要训练的手写识别的图片文件,并按照目录结构存放。见下图示意:

生成TFrecord

生成TFrecord

2)  生成训练图片和标签对应的文本文件,见下图示意:

生成TFrecord

生成TFrecord

3) 编写图片生成TFrecord代码,代码见下:

import numpy as np
import cv2
import tensorflow as tf
import matplotlib.pyplot as plt
from io import StringIO,BytesIO


# 将value转化成int64字节属性
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


# 将value转化成bytes属性
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


# 从训练文本里读取样本、标签。返回样本、标签列表以及行数
def load_file(examples_list_file):
    # type: (object) -> object
    lines = np.genfromtxt(examples_list_file, delimiter=" ", dtype=[('col1', 'S120'), ('col2', 'i8')])
    examples = []
    labels = []
    for example, label in lines:
        examples.append(example)
        labels.append(label)
    return np.asarray(examples), np.asarray(labels), len(lines)




def extract_image(filename, resize_height, resize_width):
    image = cv2.imread(filename)
    image = cv2.resize(image, (resize_height, resize_width))
    b, g, r = cv2.split(image)
    rgb_image = cv2.merge([r, g, b])
    rgb_image = rgb_image / 255.
    rgb_image = rgb_image.astype(np.float32)
    return rgb_image


def Image2TFRecord(trainDir,trainLabelFile,tfFile):
    resize_height = 28
    resize_width = 28
    train_file_root = trainDir
    train_file = trainLabelFile
    examples, labels, examples_num = load_file(train_file)
    writer = tf.python_io.TFRecordWriter(tfFile)
    for i, [example, label] in enumerate(zip(examples, labels)):
        print('No.%d' % (i))
        print(examples[i].decode(encoding="utf-8"))
        root = train_file_root + '/' + examples[i].decode(encoding="utf-8")
        print(root)
        image = extract_image(root, resize_height, resize_width)
        a = image.shape
        print(root)
        print('shape: %d, %d, %d, label: %d' % (image.shape[0], image.shape[1], image.shape[2], label))
        image_raw = image.tostring()  # 将Image转化成字符
        example = tf.train.Example(features=tf.train.Features(feature={
            'image_raw': _bytes_feature(image_raw),
            'height': _int64_feature(image.shape[0]),
            'width': _int64_feature(image.shape[1]),
            'depth': _int64_feature(image.shape[2]),
            'label': _int64_feature(label)
        }))
        writer.write(example.SerializeToString())
    writer.close()


if __name__ == '__main__':
    Image2TFRecord('E:/Python/mnist_img_data4','E:/Python/mnist_img_data4/train.txt','E:/Python/mnist_img_output/a44.tfrecords')
    TFrcords2Img('E:/Python/mnist_img_output/a44.tfrecords')


# 延展阅读 https://docs.scipy.org/doc/numpy-1.10.1/user/basics.io.genfromtxt.html
    指定按照定长3来分割
    data = "  1  2  3\n  4  5 67\n890123  4"
    print(np.genfromtxt(StringIO(data), delimiter=3))
    ''' 结果:
[[  1.   2.   3.]
 [  4.   5.  67.]
 [890. 123.   4.]]


    '''
     指定按照3列来分割,第一列长度是4,第二列长度是3,第三列长度是2
    data = "123456789\n   4  7 9\n   4567 9"
    print(np.genfromtxt(StringIO(data), delimiter=(4, 3, 2)))


    '''结果:
    [[1234.  567.   89.]
 [   4.    7.    9.]
 [   4.  567.    9.]]
    '''


    data = "1, abc , 2\n 3, xxx, 4"
     默认是不自动替换空格的
    print(np.genfromtxt(StringIO(data), delimiter=",", dtype="|S5"))
    print(np.genfromtxt(StringIO(data), delimiter=",", dtype="|S5", autostrip=True))
'''
[[b'1' b' abc ' b' 2']
 [b'3' b' xxx' b' 4']]
 
[[b'1' b'abc' b'2']
 [b'3' b'xxx' b'4']]


'''
data = "1 2 3\n4 5 6"
print(np.genfromtxt(StringIO(data), usecols=(0, -1)))
'''
[[1. 3.]
 [4. 6.]]
'''
print(np.genfromtxt(StringIO(data),names="a, b, c", usecols=("b", "c")))
'''
[(2., 3.) (5., 6.)]
'''
DataType的类型:
生成TFrecord

生成TFrecord

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/207403.html原文链接:https://javaforall.net

(0)
上一篇 2026年3月19日 下午1:54
下一篇 2026年3月19日 下午1:54


相关推荐

  • java二维数组初始化的三种方式「建议收藏」

    java二维数组初始化的三种方式「建议收藏」有些知识觉得很简单,但其中一些细节性的东西我们未必知道,比如说数组的定义以及初始化的方式。下面主要介绍下二位数组初始化的三种方式1、定义数组的同时使用大括号直接赋值,适合数组元素已知的情况2、定义二维数组的大小,然后分别赋值3、数组第二维的长度可变化//第一种方式:定义的同时大括号直接复制int[][]array1={{1,3,1},{…

    2022年5月9日
    90
  • 离散数学知识点总结

    离散数学知识点总结3 1 高级概念 k core k truss k clique k club p cohesion k edge vertexconnec k shell 同态核 像同态定理 单 甲是乙的一个子群 满 甲的一个商群是乙 非单非满 甲的一个商群是乙的一个子群 双 甲就是乙 2 一阶逻辑基本概念 个体词 常项 变项 约束 自由 假言推理 附加 化简 拒取 假言三段论 析取三段论 构造性二难 破坏性二难 合取引入 2 基本概念 点 边 邻域 前驱 后继 关联边 端点 相邻边 割 桥

    2026年3月19日
    1
  • Maven的GroupID和ArtifactID的含义「建议收藏」

    标签:目的   left   就会   定义   平时   包名   项目   rep   depend   groupID:是项目组织唯一的标识符,实际对应Java的包的结构,是main目录里Java的目录结构。artifactID:是项目的唯一标识符,实际对应项目的名称,就是项目根目录的名称。 1.基础掌握<groupId>com.yucong.commonma…

    2022年4月14日
    232
  • spss 卡方检验,Logistic回归方法「建议收藏」

    spss 卡方检验,Logistic回归方法「建议收藏」案例:新生儿体重较低影响因素1:影响因素分析,求出哪些自变量对因变量发生概率有影响,并计算各自变量对因变量的比数比;2:作为判别分析方法,来估计各种自变量组合条件下因变量各类别的发生概率,从而对结局进行预测,该模型在结果上等价于判别分析;说明:低出生体重标准:新生儿体重<2500克结果变量为是否娩出低出生体重儿,变量名为low,1=低出生体重,0=非低出生体重;考虑的影响因素…

    2022年5月16日
    52
  • 字典序算法详解

    字典序算法详解一 字典序字典序 就是按照字典中出现的先后顺序进行排序 1 单个字符在计算机中 25 个字母以及数字字符 字典排序如下 0 lt 1 lt 2 lt lt 9 lt a lt b lt lt z 比如在 python 中

    2026年3月19日
    2
  • vue + echarts 省份地图 以及打包后地图加载不出来(比较详细)「建议收藏」

    vue + echarts 省份地图 以及打包后地图加载不出来(比较详细)「建议收藏」刚开始地图怎么也出不来,经过解决,是因为echarts.min.js引入位置在index.html中引入需要的js版本按照自己需要的来<scriptsrc=”./static/plugins/echarts-5.1.2/echarts.common.min.js”></script><scriptsrc=”./static/plugins/echarts.min.js”></script>(必须引入,地图才能加载)全局引入im

    2022年10月12日
    3

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号