tfrecord文件生成与读取

tfrecord文件生成与读取参考博客 tensorflow TFRecord 文件详解 1 生成 tfrecord 文件代码 1 创建 tfrecord 对象 tf record tf python io TFRecordWrit tf record name tf train Int64List value list data tf train FloatList tf train BytesList tf train Feature int64 list tf train Feature float l

参考博客——tensorflow-TFRecord 文件详解

1. 生成tfrecord文件

在这里插入图片描述
代码

#1.创建tfrecord对象 tf_record=tf.python_io.TFRecordWriter(tf_record_name) tf.train.Int64List(value=list_data) tf.train.FloatList( ) tf.train.BytesList() tf.train.Feature(int64_list=) tf.train.Feature(float_list=tf.train.FloatList()) tf.train.Feature(bytes_list=tf.train.BytesList()) tf.train.Features(feature=dict_data) ut = tf.train.Features(feature={ 
   "suibian": tf.train.Feature(int64_list=tf.train.Int64List(value=[1, 2, 4])),"a":tf.train.Feature(float_list=tf.train.FloatList(value=[5., 7.]))}) example=tf.train.Example(features=tf.train.Features(...)) #2. 写入example对象序列化后的结果 tfrecord_writer.write(example.SerializeToString()) 

2. 读取tfrecord文件

从文件读取有 3 大步骤

  1. 生成读取器,不同类型的文件有对应的读取器
  2. 把文件名列表生成队列
  3. 用读取器的 read 方法读取队列中的文件
    在这里插入图片描述
    在这里插入图片描述




3 代码

3.1 dataset_to_tfrecord.py

在这里插入图片描述

import os import xml.etree.ElementTree as ET import tensorflow as tf from dataset_config import DIRECTORY_ANNOTATIONS,DIRECTORY_IMAGES,NUM_IMAGES_TFRECORD,labels_to_class from utils.data_process_util import int64_feature,float_feature,bytes_feature def _convert_to_example(img,img_shape,labels,trunacted,difficult,bndbox_size): '''将一张图片使用example,转换成protobuffer 格式 :param img: :param img_shape: :param labels: :param trunacted: :param difficult: :param bndbox_size: :return: ''' # 为了转换需求,bbox由单个obj的四个位置值, # 转变成四个位置的单独列表 # 即:[[12,120,330,333],[50,60,100,200]]————>[[12,50],[120,60],[330,100],[333,200]] ymin=[] xmin=[] ymax=[] xmax=[] for b in bndbox_size: ymin.append(b[0]) xmin.append(b[1]) ymax.append(b[2]) xmax.append(b[3]) img_format = b'JPEG' print(type(labels)) for i,label in enumerate(labels): labels[i]=labels_to_class[label] print('trunacted:',trunacted,type(trunacted),len(trunacted)) example = tf.train.Example(features=tf.train.Features(feature={ 
    'image/height':int64_feature(img_shape[0]), 'image/width':int64_feature(img_shape[1]), 'image/channels':int64_feature(img_shape[2]), 'image/shape':int64_feature(img_shape), 'image/object/bbox/xmin':float_feature(xmin), 'image/object/bbox/ymin':float_feature(ymin), 'image/object/bbox/xmax':float_feature(xmax), 'image/object/bbox/ymax':float_feature(ymax), 'image/object/bbox/label_text':int64_feature(labels), # 'image/object/bbox/trunacted':bytes_feature(trunacted), # 'image/object/bbox/difficult':bytes_feature(difficult), 'image/object/bbox/format':bytes_feature(img_format), 'image/object/bbox/data':bytes_feature(img)# 读取的图像值 })) print(img_format) return example def _process_image(dataset_dir,img_name): ''' 读取图像和xml文件 :param dataset_dir: :param img_name: :return: ''' #1.读取图像 #图像路径 img_path = os.path.join(dataset_dir,DIRECTORY_IMAGES,img_name+'.jpg') img = tf.gfile.FastGFile(img_path,'rb').read()#tensorflow读取图像 #2.读取xml #xml路径 xml_path =os.path.join(dataset_dir,DIRECTORY_ANNOTATIONS,img_name+'.xml') tree = ET.parse(xml_path) root = tree.getroot()#获取根节点,'annotation'标签 # 2.1获取图像尺寸信息 size = root.find('size') img_shape=[ int(size.find('height').text), int(size.find('width').text), int(size.find('depth').text) ] #2.2 获取bounding box 相关信息 # bounding box可能有多个,用多个列表存储相关信息。 labels = [] trunacted=[] difficult = [] bndbox_sizes=[] bboxes = root.findall('object') for obj in bboxes: label = obj.find('name').text if obj.find('trunacted'): trunacted.append(obj.find('trunacted').text) else: trunacted.append('0') if obj.find(''): difficult.append(obj.find('difficult').text) else: difficult.append(0) bndbox = obj.find('bndbox') bndbox_size=( float(bndbox.find('ymin').text)/img_shape[0], float(bndbox.find('xmin').text)/img_shape[1], float(bndbox.find('ymax').text)/img_shape[0], float(bndbox.find('xmax').text)/img_shape[1] ) labels.append(label) trunacted.append(trunacted) difficult.append(difficult) bndbox_sizes.append(bndbox_size) return img,img_shape,labels,trunacted,difficult,bndbox_sizes def _add_to_tfrecord(dataset_dir,img_name,tfrecord_writer): ''' 读取图片和xml文件,保存成一个Example :param dataset_dir:根目录 :param img_name:图像名称 :param tfrecord_writer: :return: ''' #1.读取图片内容及相应的xml文件 img, img_shape, labels, trunacted, difficult, bndbox_size=_process_image(dataset_dir,img_name) # return img,img_shape,labels,trunacted,difficult,bndbox_size #2.读取的内容封装成Example, example = _convert_to_example(img, img_shape, labels, trunacted, difficult, bndbox_size) #3.Example序列化结果写入指定tfrecord文件 tfrecord_writer.write(example.SerializeToString()) def _get_output_tfrecord_name(output_dir,name,fdx): """ :param output_dir: :param name: :param fdx:第几个tfrecord文件 :return: """ return os.path.join(output_dir,name,'%06d'%fdx+'.tfrecord') def read_tfrecord(): slim = tf.contrib.slim dataset = slim.dataset #第一个参数,文件路径 file_pattern = os.path.join('tf_records\data','*.record') #第二个参数 reader = tf.TFRecordReader # file_pattern = '%s-* ' # 前面保存的tfrecord文件的文件名类似于“train-00001-of-00004.tfrecord” # file_pattern = os.path.join(dataset_dir, file_pattern % split_name) # dataset_dir即前面保存的tfrecord文件的路径 # 使用slim中的函数tf.FixedLenFeature将tfrecord的example反序列化成存储之前的格式, # 字符串格式的用''表示,整型格式的用0表示,其他确定的信息还原为原来的形式,如'jpeg','png' keys_to_features = { 
    'image/object/bbox/data': tf.FixedLenFeature((), tf.string, default_value=''), 'image/object/bbox/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'), 'image/height': tf.FixedLenFeature((), tf.int64, default_value=0), 'image/width': tf.FixedLenFeature((), tf.int64, default_value=0), 'image/object/bbox/label_text': tf.FixedLenFeature((), tf.int64, default_value=0)} # 将反序列化的数据重组为更适合网络读入的格式 items_to_handlers = { 
    'image': slim.tfexample_decoder.Image( image_key='image/object/bbox/data', format_key='image/object/bbox/format', channels=3), # 'image_name': tfexample_decoder.Tensor('image/filename'), 'height': slim.tfexample_decoder.Tensor('image/height'), 'width': slim.tfexample_decoder.Tensor('image/width'), # 'labels_class': tfexample_decoder.Image( # image_key='image/segmentation/class/encoded', # format_key='image/segmentation/class/format', # channels=1) } # 解码器进行解码,定义一个解码器对象,保存到dataset中 # 第三个参数decoder decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers) # 返回由tfrecord信息所得到的数据集dataset,dataset对象定义了数据集的文件位置,解码方式等元信息 dataset = dataset.Dataset( data_sources=file_pattern, # tfrecord路径 reader=tf.TFRecordReader, # 读取tfrecord文件的方式 decoder=decoder, # 解码tfrecord文件的方式 num_samples=1464, # PASCAL-VOC2012数据集训练样本数 items_to_descriptions={ 
    # 样本集图像和标签描述 'image': 'A color image of varying height and width.', 'labels_class': ('A semantic segmentation label whose size matches image.' 'Its values range from 0 (background) to num_classes.')}, num_classes = 3, # 数据集包含类别数(20个前景类别和1个背景类别) multi_label = True) # 多标签(具体我也不太清楚) dataset_data_provider = slim.dataset_data_provider prefetch_queue = slim.prefetch_queue # 创建一个DatasetDataProvider类的对象data_provider,根据dataset和其他的一些已知信息读取数据。 data_provider = dataset_data_provider.DatasetDataProvider( dataset, num_readers=1, num_epochs=None, shuffle=True) # 通过调用data_provider对象的get实例函数能够根据data_provider中给出的信息解读tfrecord文件,生成图像和标签和图像文件名 image, height, width = data_provider.get(['image', 'height', 'width']) # image_name, = data_provider.get(['image_name']) # label = data_provider.get(['label']) # 图像预处理过程,这里具体的处理过程与本文主题无关,因此省略具体的处理过程 return image, height, width def run(dataset_dir,output_dir,name='data'): """ 运行转换代码逻辑。 存入多个tfrecord文件,每个文件固定N个样本 :param dataset_dir:数据集目录,包含annotations,jpeg文件夹 :param output_dir:tfrecords存储目录 :param name:数据集名字,指定名字以及train or test :return: """ # 1. 判断数据集目录是否存在,创建一个目录 if tf.gfile.Exists(dataset_dir): tf.gfile.MakeDirs(dataset_dir) # 输出路径需要已存在 # if tf.gfile.Exists(output_dir): # tf.gfile.MakeDirs(output_dir) # 2. 读取某个文件夹下的所有文件名字列表 dir_path = os.path.join(dataset_dir,DIRECTORY_ANNOTATIONS) files_path = sorted(os.listdir(dir_path)) print(files_path) # 3. 循环名字列表, # 每200(NUM_IMAGES_TFRECORD)个图片及xml文件存储到一个tfrecord文件中 num = len(files_path) i = 0 fdx = 0 while i < num: tf_record_name = _get_output_tfrecord_name(output_dir,name,fdx) with tf.python_io.TFRecordWriter(tf_record_name) as tf_record_writer: j = 0 while i<num and j < NUM_IMAGES_TFRECORD: xml_path = files_path[i] img_name = xml_path.split('.')[0] #每个图像构建一个Example,保存到tf_record_name中 _add_to_tfrecord(dataset_dir,img_name,tf_record_writer) j += 1 i += 1 fdx += 1 print('fdx',fdx) print('数据集%s转换成功'%(dataset_dir)) 

3.2 tfrecord文件读取

在这里插入图片描述

def read_tfrecord(): slim = tf.contrib.slim dataset = slim.dataset #第一个参数,文件路径 file_pattern = os.path.join('tf_records\data','*.tfrecord') #第二个参数 reader = tf.TFRecordReader # file_pattern = '%s-* ' # 前面保存的tfrecord文件的文件名类似于“train-00001-of-00004.tfrecord” # file_pattern = os.path.join(dataset_dir, file_pattern % split_name) # dataset_dir即前面保存的tfrecord文件的路径 # 使用slim中的函数tf.FixedLenFeature将tfrecord的example反序列化成存储之前的格式, # 字符串格式的用''表示,整型格式的用0表示,其他确定的信息还原为原来的形式,如'jpeg','png' keys_to_features = { 
    'image/object/bbox/data': tf.FixedLenFeature((), tf.string, default_value=''), 'image/object/bbox/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'), 'image/height': tf.FixedLenFeature((), tf.int64, default_value=0), 'image/width': tf.FixedLenFeature((), tf.int64, default_value=0), 'image/object/bbox/label_text': tf.FixedLenFeature((), tf.int64, default_value=0)} # 将反序列化的数据重组为更适合网络读入的格式 items_to_handlers = { 
    'image': slim.tfexample_decoder.Image( image_key='image/object/bbox/data', format_key='image/object/bbox/format', channels=3), # 'image_name': tfexample_decoder.Tensor('image/filename'), 'height': slim.tfexample_decoder.Tensor('image/height'), 'width': slim.tfexample_decoder.Tensor('image/width'), # 'labels_class': tfexample_decoder.Image( # image_key='image/segmentation/class/encoded', # format_key='image/segmentation/class/format', # channels=1) } # 解码器进行解码,定义一个解码器对象,保存到dataset中 # 第三个参数decoder decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers) # 返回由tfrecord信息所得到的数据集dataset,dataset对象定义了数据集的文件位置,解码方式等元信息 dataset = dataset.Dataset( data_sources=file_pattern, # tfrecord路径 reader=tf.TFRecordReader, # 读取tfrecord文件的方式 decoder=decoder, # 解码tfrecord文件的方式 num_samples=1464, # PASCAL-VOC2012数据集训练样本数 items_to_descriptions={ 
    # 样本集图像和标签描述 'image': 'A color image of varying height and width.', 'labels_class': ('A semantic segmentation label whose size matches image.' 'Its values range from 0 (background) to num_classes.')}, num_classes = 3, # 数据集包含类别数(20个前景类别和1个背景类别) multi_label = True) # 多标签(具体我也不太清楚) dataset_data_provider = slim.dataset_data_provider prefetch_queue = slim.prefetch_queue # 创建一个DatasetDataProvider类的对象data_provider,根据dataset和其他的一些已知信息读取数据。 data_provider = dataset_data_provider.DatasetDataProvider( dataset, num_readers=1, num_epochs=None, shuffle=True) # 通过调用data_provider对象的get实例函数能够根据data_provider中给出的信息解读tfrecord文件,生成图像和标签和图像文件名 image, height, width = data_provider.get(['image', 'height', 'width']) # image_name, = data_provider.get(['image_name']) # label = data_provider.get(['label']) # 图像预处理过程,这里具体的处理过程与本文主题无关,因此省略具体的处理过程 return image, height, width 
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/219897.html原文链接:https://javaforall.net

(0)
上一篇 2026年3月17日 下午9:38
下一篇 2026年3月17日 下午9:38


相关推荐

  • Android ROM 制作教程

    Android ROM 制作教程

    2021年12月13日
    64
  • 什么是跨域?跨域解决方法

    什么是跨域?跨域解决方法一、为什么会出现跨域问题出于浏览器的同源策略限制。同源策略(Sameoriginpolicy)是一种约定,它是浏览器最核心也最基本的安全功能,如果缺少了同源策略,则浏览器的正常功能可能都会受到影响。可以说Web是构建在同源策略基础之上的,浏览器只是针对同源策略的一种实现。同源策略会阻止一个域的javascript脚本和另外一个域的内容进行交互。所谓同源(即指在同一个域)就是两个页面具有相同的协…

    2022年4月28日
    55
  • 八数码问题求解「建议收藏」

    八数码问题求解「建议收藏」(一)问题描述在一个3*3的方棋盘上放置着1,2,3,4,5,6,7,8八个数码,每个数码占一格,且有一个空格。这些数码可以在棋盘上移动,其移动规则是:与空格相邻的数码方格可以移入空格。现在的问题是:对于指定的初始棋局和目标棋局,给出数码的移动序列。该问题称八数码难题或者重排九宫问题。(二)问题分析八数码问题是个典型的状态图搜索问题。搜索方式有两种基本的方式,即树式搜索和线式搜索。搜索策略大体有盲…

    2022年7月26日
    7
  • 指标异动分析「建议收藏」

    指标异动分析「建议收藏」What业务都会面对“为什么涨、为什么降、原因是什么?”,因此日常数据分析80%总是在围绕指标异动做分析,进行原因定位,常见的指标异动分析例如GMV、DAU等为何下降?Why指标异动分析有利于为业务方建立业务抓手,及时定位业务异常原因,进而制定相应的运营调整策略,保障业务正常稳定发展How1、明确异常指标波动标准(净值百分比)业务指标会随着内外部环境变动而不断变化,数据的波动主要体现在变动日期与基准日期的对比(同环比)出现上升或下降。指标波动通常分为周期性波动、突发性波动、持续性波动。**

    2022年6月10日
    53
  • Apache 服务器和Tomcat 服务器的区别

    Apache 服务器和Tomcat 服务器的区别nbsp nbsp nbsp nbsp 最近工作总是接触到 Apache 和 Tomcat 服务器 它们到底有什么区别 还是有点模糊 下面梳理一下 nbsp nbsp nbsp nbsp Apache 是 Web 服务器 静态解析 如 HTML Tomcat 是 Java 应用服务器 动态解析 如 JSP 请参考 web 服务器与应用服务器的区别 nbsp nbsp nbsp nbsp Tomcat 是一个 Servlet JSP 容器 是 Apache 的扩展 可以独立于 Apache

    2026年3月16日
    2
  • CQRS架构

    CQRS架构命令查询的责任分离 CommandQuery 简称 CQRS 模式是一种架构体系模式 能够使改变模型的状态的命令和模型状态的查询实现分离 这属于 DDD 应用领域的一个模式 为了使得项目逻辑更加清晰 便于对不同部分进行针对性的优化 一 背景问题在以前的管理系统中 命令 Command 通常用来更新数据 操作 DB 和查询 Query 通常使用

    2026年3月19日
    2

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号