TFRecord格式介绍

    xiaoxiao2023-10-15  21

    TFRecord格式

    TFRecord文件中的数据都是通过tf.train.Example Protocol Buffer的格式存储的,TFRecord格式是一种二进制文件,它能够更好的利用内存,更方便复制和移动,并且不需要单独的标签文件;我们可以写一段代码获取你的数据,然后将数据填入到Example协议内存块(protocol buffer),将协议内存块序列化为一个字符串,并且通过tf.python_io.TFRecordWriter写入到TFRecord文件中去

    从TFRecords文件中读取数据, 可以使用tf.TFRecordReader的tf.parse_single_example解析器。这个操作可以将Example协议内存块(protocol buffer)解析为张量。

     

    将MNIST数据集中所有的训练数据存储到TFRecord文件中

    import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data import numpy as np # 生成整数型的属性 def _int64_feature(value): return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) # 生成字符串型的属性 def _bytes_feature(value): return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) mnist = input_data.read_data_sets('/path/to/MNIST_DATA/', dtype=tf.uint8, one_hot=True) images = mnist.train.images labels = mnist.train.labels # 训练数据的图像分辨率,这里可以作为一个属性保存在TFRecord中 pixels = images.shape[1] num_examples = mnist.train.num_examples # 输出TFRecord的地址 file_name = '/path/to/output.tfrecords' # 创建一个writer来写TFRecord文件 writer = tf.python_io.TFRecordWriter(file_name) # writer = tf.python_io.TFRecordWriter(file_name) for index in range(num_examples): # 将图像转换为一个字符串 image_raw = images[index].tostring() # 讲一个样例转换为Example Protocol Buffer, 并将所有的信息写入这个数据结构 example = tf.train.Example(features=tf.train.Features(feature={ 'pixels': _int64_feature(pixels), 'labels': _int64_feature(np.argmax(labels[index])), 'image_raw':_bytes_feature(image_raw) })) # 将一个Example写入TFRecord文件 writer.write(example.SerializeToString()) print('写入成功') writer.close()

    输出结果:

    Extracting /path/to/MNIST_DATA/train-images-idx3-ubyte.gz Extracting /path/to/MNIST_DATA/train-labels-idx1-ubyte.gz Extracting /path/to/MNIST_DATA/t10k-images-idx3-ubyte.gz Extracting /path/to/MNIST_DATA/t10k-labels-idx1-ubyte.gz 写入成功

    读取TFRecord文件中的数据

    import tensorflow as tf # 创建一个reader来读取TFRecord文件中的样例 reader = tf.TFRecordReader() # 创建一个队列来维护输入文件列表 file_queue = tf.train.string_input_producer(['/path/to/output.tfrecords']) # 从文件中读出一个样例,也可以使用read_up_to函数一次性读取多个样例 _, serialized_example = reader.read(file_queue) # 解读入的一个样例,如果需要解析多个样例,使用parse_example函数 features = tf.parse_single_example(serialized_example, features={ # tensorflow提供两种不同的属性解析方法,一种方法是tf.FixedLenFeature, # 这种方法解析到的是一个Tensor,另一种方法是tf.VarLenFeature,这种方法得到的解析结果 # 为SparseTensor, 用于处理稀疏数据;这里解析数据的格式要和上面程序写入数据的格式一致 'image_raw':tf.FixedLenFeature([], tf.string), 'pixels':tf.FixedLenFeature([], tf.int64), 'labels':tf.FixedLenFeature([], tf.int64), }) # tf.decode_raw可以将字符串解析为图像对应的像素数组 image = tf.decode_raw(features['image_raw'], tf.uint8) label = tf.cast(features['labels'], tf.int32) pixels = tf.cast(features['pixels'], tf.int32) sess = tf.Session() # 启动多线程处理输入数据 coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess=sess, coord=coord) print('------------') # 每次运行可以读取TFRecord文件中的一个样例,当所有的样例都读完的时候,在此样例中程序会重头读取 # 读取前十个数据 for i in range(10): print(sess.run([image, label, pixels]))

    运行结果:

    输出mnist数据集前十个像素数组和对应的label、pixels

    最新回复(0)