TFRecord 和 tf.Example
写tfrecord文件
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data


def int64_feature(value):
    """Wrap a single integer in a tf.train.Feature backed by an Int64List."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def bytes_feature(value):
    """Wrap a single bytes value in a tf.train.Feature backed by a BytesList."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


mnist = input_data.read_data_sets('./data', dtype=tf.uint8, one_hot=True)
images = mnist.train.images
labels = mnist.train.labels
size = images.shape[1]
num_examples = mnist.train.num_examples

# Path of the output TFRecord file.
filename = './output.tfrecord'

# Create the writer as a context manager so the file is flushed and closed
# even if building or serializing an example raises part-way through.
with tf.io.TFRecordWriter(filename) as writer:
    for i in range(num_examples):
        # Serialize the image array to raw bytes. `tobytes()` replaces the
        # deprecated `tostring()` alias and returns the same data.
        image_raw = images[i].tobytes()
        # Convert one sample into an Example protocol buffer holding all of
        # its information. Labels are one-hot, so argmax recovers the class
        # index; int() turns the NumPy scalar into a plain Python int.
        example = tf.train.Example(features=tf.train.Features(feature={
            'size': int64_feature(size),
            'label': int64_feature(int(np.argmax(labels[i]))),
            'image_raw': bytes_feature(image_raw),
        }))
        # Append the serialized Example to the TFRecord file.
        writer.write(example.SerializeToString())
读取tfrecord文件
import tensorflow as tf

# Create a reader that pulls serialized examples out of TFRecord files.
reader = tf.TFRecordReader()
# Create a queue that maintains the list of input files.
filename_queue = tf.train.string_input_producer(['./output.tfrecord'])
# Read one example from the file; read_up_to can be used instead to read
# several examples at once.
_, serialized_example = reader.read(filename_queue)
# Parse the single example that was read; use parse_example when several
# examples have to be parsed together.
features = tf.parse_single_example(
    serialized_example,
    features={
        'image_raw': tf.FixedLenFeature([], tf.string),
        'size': tf.FixedLenFeature([], tf.int64),
        'label': tf.FixedLenFeature([], tf.int64),
    })
image = tf.decode_raw(features['image_raw'], tf.uint8)
label = tf.cast(features['label'], tf.int32)
size = tf.cast(features['size'], tf.int32)

# Start the queue-runner threads that feed the input pipeline.
coord = tf.train.Coordinator()
with tf.Session() as sess:
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        # Each run reads one example from the TFRecord file. Once every
        # example has been read, this demo starts again from the beginning.
        for i in range(10):
            print(sess.run([image, label, size]))
    finally:
        # Ask the queue-runner threads to stop and wait for them, so they are
        # not still running when the session is closed.
        coord.request_stop()
        coord.join(threads)
tf.Example
import tensorflow as tf
import numpy as np


# The following functions can be used to convert a value to a type compatible
# with tf.Example. Each wraps a single scalar in a one-element list.
def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    if isinstance(value, type(tf.constant(0))):
        # BytesList won't unpack a string from an EagerTensor.
        value = value.numpy()
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _float_feature(value):
    """Returns a float_list from a float(float32) / double(float64)."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))


def _int64_feature(value):
    """Returns an int64_list from a bool / enum / int32 / uint32 / int64 / uint64."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def serialize_example(feature0, feature1, feature2, feature3):
    """Creates a tf.Example message ready to be written to a file.

    Args:
        feature0: value stored as an int64 feature under key 'feature0'.
        feature1: value stored as an int64 feature under key 'feature1'.
        feature2: bytes value stored under key 'feature2'.
        feature3: float value stored under key 'feature3'.

    Returns:
        The serialized tf.train.Example as a bytes string.
    """
    # Create a dictionary mapping the feature name to the
    # tf.Example-compatible data type.
    feature = {
        'feature0': _int64_feature(feature0),
        'feature1': _int64_feature(feature1),
        'feature2': _bytes_feature(feature2),
        'feature3': _float_feature(feature3),
    }
    # Create a Features message using tf.train.Example.
    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
    return example_proto.SerializeToString()


print(_bytes_feature(b'test_string'))
print(_bytes_feature(u'test_bytes'.encode('utf-8')))
print(_float_feature(np.exp(1)))
print(_int64_feature(True))
print(_int64_feature(1))
# Output (floats are stored as float32; the exact number of digits printed
# may vary with the protobuf version):
#   bytes_list { value: "test_string" }
#   bytes_list { value: "test_bytes" }
#   float_list { value: 2.7182817 }
#   int64_list { value: 1 }
#   int64_list { value: 1 }

serialized_example = serialize_example(False, 4, b'goat', 0.9876)
print(serialized_example)
# b'\nR\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04[\xd3|?'

example_proto = tf.train.Example.FromString(serialized_example)
print(example_proto)
# Output (the float32 bytes '[\xd3|?' above decode to approximately 0.9876):
#   features {
#     feature { key: "feature0" value { int64_list { value: 0 } } }
#     feature { key: "feature1" value { int64_list { value: 4 } } }
#     feature { key: "feature2" value { bytes_list { value: "goat" } } }
#     feature { key: "feature3" value { float_list { value: 0.9876 } } }
#   }
本示例中处理的是单个整型、浮点类型、字节类型,因此value=[value]中对value加了[]使其具有可迭代性,如果需要存储的数据本身就有可迭代性则不能再加[],例如如果是要存储[1.1,1.2,1.3],则对应的函数_float_feature应该写成:
def _float_feature(value):
    """Returns a float_list built from an iterable of floats, e.g. [1.1, 1.2, 1.3]."""
    # The value is already iterable, so it is passed through without being
    # wrapped in another list.
    float_list = tf.train.FloatList(value=value)
    return tf.train.Feature(float_list=float_list)
完整示例
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt


# The following functions can be used to convert a value to a type compatible
# with tf.Example.
def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    if isinstance(value, type(tf.constant(0))):
        # BytesList won't unpack a string from an EagerTensor.
        value = value.numpy()
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _float_feature(value):
    """Returns a float_list from a float(float32) / double(float64)."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))


def _int64_feature(value):
    """Returns an int64_list from a bool / enum / int32 / uint32 / int64 / uint64."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _build_image_feature(image_shape, img_raw, label):
    """Build the feature dict shared by both image_example variants."""
    return {
        'height': _int64_feature(image_shape[0]),
        'width': _int64_feature(image_shape[1]),
        'depth': _int64_feature(image_shape[2]),
        'label': _int64_feature(label),
        'image_raw': _bytes_feature(img_raw),
    }


def image_example(img_raw, label):
    """Create a tf.train.Example from encoded JPEG bytes and a label.

    Newer TensorFlow versions allow reading the shape directly from the
    decoded tensor, which is what this variant relies on.

    Args:
        img_raw: encoded JPEG image bytes.
        label: integer label stored alongside the image.

    Returns:
        A tf.train.Example with 'height', 'width', 'depth', 'label' and
        'image_raw' features.
    """
    img_tensor = tf.image.decode_jpeg(img_raw)
    feature = _build_image_feature(img_tensor.shape, img_raw, label)
    return tf.train.Example(features=tf.train.Features(feature=feature))


def image_example_sess(img_raw, label, sess):
    """Create a tf.train.Example, evaluating the decoded image in a session.

    Older TensorFlow versions cannot read the shape directly after decoding,
    so the tensor is evaluated to a numpy.ndarray and the shape is taken from
    that array.

    Args:
        img_raw: encoded JPEG image bytes.
        label: integer label stored alongside the image.
        sess: tf.Session used to evaluate the decode op.

    Returns:
        A tf.train.Example with 'height', 'width', 'depth', 'label' and
        'image_raw' features.
    """
    img_tensor = tf.image.decode_jpeg(img_raw)
    with sess.as_default():
        img_data = img_tensor.eval()
    print(type(img_data))
    feature = _build_image_feature(img_data.shape, img_raw, label)
    return tf.train.Example(features=tf.train.Features(feature=feature))


# Example usage of image_example_sess on a single file:
#
#     with tf.gfile.GFile('test1.jpg', 'rb') as image_file:
#         img_raw = image_file.read()
#     label = 0
#     with tf.Session() as sess:
#         print(image_example_sess(img_raw, label, sess))

# Write the TFRecord file.
# Write the raw image files to `images.tfrecords`.
# First, process the two images into `tf.Example` messages.
# Then, write to a `.tfrecords` file.
image_labels = {
    'test1.jpg': 0,
    'test2.jpg': 1,
}
record_file = 'images.tfrecords'
with tf.Session() as sess:
    with tf.io.TFRecordWriter(record_file) as writer:
        for filename, label in image_labels.items():
            # Read through a context manager so the file handle is closed
            # (the original FastGFile(...).read() never closed it).
            with tf.gfile.GFile(filename, 'rb') as image_file:
                img_raw = image_file.read()
            tf_example = image_example_sess(img_raw, label, sess)
            writer.write(tf_example.SerializeToString())

# Build the image-decoding op once and feed it through a placeholder. Calling
# tf.io.decode_image(...) inside the read loops would add new ops to the
# graph on every iteration.
encoded_image = tf.placeholder(tf.string, shape=[])
decoded_image = tf.io.decode_image(encoded_image)


def _print_features(feature_dict_val):
    """Print the scalar features of one parsed example (or batch)."""
    print('height: ', feature_dict_val['height'])
    print('width: ', feature_dict_val['width'])
    print('depth: ', feature_dict_val['depth'])
    print('label: ', feature_dict_val['label'])


# Read the TFRecord file.
input_files = ['images.tfrecords']  # more than one file may be listed
raw_image_dataset = tf.data.TFRecordDataset(input_files)


def _parse_image_function(example_proto):
    """Parse one serialized tf.Example into a dict of tensors."""
    # Create a dictionary describing the features.
    image_feature_description = {
        # height, width and depth each hold a single number, so the shape
        # list can be left empty.
        'height': tf.io.FixedLenFeature([], tf.int64),
        'width': tf.io.FixedLenFeature([], tf.int64),
        'depth': tf.io.FixedLenFeature([], tf.int64),
        # The label here is a single number, so the shape can be left empty.
        # A detection label with 4 numbers would need [4] here (matching what
        # was written); otherwise parsing fails with
        # "Can't parse serialized Example."
        'label': tf.io.FixedLenFeature([], tf.int64),
        'image_raw': tf.io.FixedLenFeature([], tf.string),
    }
    # Parse the input tf.Example proto using the dictionary above.
    return tf.io.parse_single_example(example_proto, image_feature_description)


parsed_image_dataset = raw_image_dataset.map(_parse_image_function)
iterator = parsed_image_dataset.make_one_shot_iterator()
feature_dict = iterator.get_next()
with tf.Session() as sess:
    for i in range(len(image_labels)):
        feature_dict_val = sess.run(feature_dict)
        _print_features(feature_dict_val)
        img = sess.run(
            decoded_image,
            feed_dict={encoded_image: feature_dict_val['image_raw']})
        plt.imshow(img)
        plt.show()

# Read the TFRecord file, with the file paths supplied through a placeholder.
input_files = tf.placeholder(tf.string)
dataset = tf.data.TFRecordDataset(input_files)
dataset = dataset.map(_parse_image_function)
# Define an initializable iterator to traverse the dataset.
iterator = dataset.make_initializable_iterator()
feature_dict = iterator.get_next()
with tf.Session() as sess:
    sess.run(
        iterator.initializer,
        feed_dict={input_files: ['images.tfrecords', 'images.tfrecords']})
    # Traverse all the data for one epoch; OutOfRangeError is raised when the
    # data is exhausted. With dynamically supplied inputs the amount of data
    # is unknown, and this pattern avoids needing to know it in advance.
    while True:
        try:
            feature_dict_val = sess.run(feature_dict)
        except tf.errors.OutOfRangeError:
            break
        _print_features(feature_dict_val)
        img = sess.run(
            decoded_image,
            feed_dict={encoded_image: feature_dict_val['image_raw']})
        plt.imshow(img)
        plt.show()

# Read shuffled batches over several epochs.
input_files = ['images.tfrecords']  # more than one file may be listed
dataset = tf.data.TFRecordDataset(input_files)
# repeat() must come before batch(): the file written above holds only two
# records, so batching first yields batches of 2 and indexing elements 2..9
# raises IndexError. Repeating first gives 2 * 5 = 10 records to batch.
dataset = dataset.map(_parse_image_function).shuffle(10).repeat(5).batch(10)
iterator = dataset.make_one_shot_iterator()
feature_dict = iterator.get_next()
with tf.Session() as sess:
    while True:
        try:
            feature_dict_val = sess.run(feature_dict)
        except tf.errors.OutOfRangeError:
            break
        _print_features(feature_dict_val)
        # Create the figure only after a batch was fetched, so no empty
        # figure is left behind when the dataset is exhausted.
        fig = plt.figure()
        # Lay the images out on a 2 x 5 grid. Iterating over the actual batch
        # keeps this safe if the final batch holds fewer than 10 images.
        for index, raw_image in enumerate(feature_dict_val['image_raw']):
            ax = fig.add_subplot(2, 5, index + 1)
            ax.imshow(sess.run(decoded_image,
                               feed_dict={encoded_image: raw_image}))
        plt.show()
发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/224498.html 原文链接:https://javaforall.net
