Python中TensorFlow的TFRecord和QueueRunner简介与使用

如何将数据集转换为 TensorFlow 的 TFRecord 格式呢?为什么要转换为 TFRecord 格式?如何使用文件队列?如何生成 Batches ?

TensorFlow 的 TFRecord 和 QueueRunner 简介


Python中TensorFlow的TFRecord和QueueRunner简介与使用
7 回复

原来那个 tf.contrib.data 在 1.4 并入 tf 变成 tf.data 的数据预处理模块楼主用过吗?好多地方都推荐用这个


TFRecord是TensorFlow的一种二进制数据格式,用于高效存储和读取大规模数据集。它把数据序列化成tf.train.Example Protocol Buffers格式,能有效减少I/O开销,特别适合处理图像、文本等非结构化数据。

QueueRunner是TensorFlow早期版本(1.x)中用于异步数据读取和预取的核心机制。它管理着一组队列(如tf.RandomShuffleQueue),并通过启动多个线程来执行入队操作,确保训练时数据能持续供给,避免GPU等待数据。

下面是一个完整的示例,展示如何创建TFRecord文件,并使用QueueRunner配合tf.train.string_input_producer进行读取和训练:

import tensorflow as tf
import numpy as np

# 1. 创建示例TFRecord文件
def write_tfrecord(filename):
    writer = tf.python_io.TFRecordWriter(filename)
    for i in range(10):
        # 创建示例数据
        feature = {
            'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[i])),
            'image': tf.train.Feature(bytes_list=tf.train.BytesList(
                value=[np.random.randn(28, 28, 3).astype(np.float32).tobytes()]))
        }
        example = tf.train.Example(features=tf.train.Features(feature=feature))
        writer.write(example.SerializeToString())
    writer.close()

# 2. 读取和解析TFRecord
def read_and_decode(filename_queue):
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'image': tf.FixedLenFeature([], tf.string)
        })
    image = tf.decode_raw(features['image'], tf.float32)
    image = tf.reshape(image, [28, 28, 3])
    label = tf.cast(features['label'], tf.int32)
    return image, label

# 3. 使用QueueRunner构建输入管道
def input_pipeline(filenames, batch_size, num_epochs=None):
    filename_queue = tf.train.string_input_producer(
        filenames, num_epochs=num_epochs, shuffle=True)
    image, label = read_and_decode(filename_queue)
    
    # 使用tf.train.shuffle_batch进行批处理
    min_after_dequeue = 1000
    capacity = min_after_dequeue + 3 * batch_size
    image_batch, label_batch = tf.train.shuffle_batch(
        [image, label], batch_size=batch_size, capacity=capacity,
        min_after_dequeue=min_after_dequeue)
    
    return image_batch, label_batch

# 主程序
if __name__ == "__main__":
    # 写入示例文件
    write_tfrecord('data.tfrecord')
    
    # 构建计算图
    image_batch, label_batch = input_pipeline(['data.tfrecord'], batch_size=4)
    
    # 定义简单模型
    logits = tf.layers.dense(tf.layers.flatten(image_batch), units=10)
    loss = tf.losses.sparse_softmax_cross_entropy(labels=label_batch, logits=logits)
    train_op = tf.train.AdamOptimizer(0.001).minimize(loss)
    
    # 训练
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())  # 用于string_input_producer的epoch计数
        
        # 启动QueueRunner
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        
        try:
            for step in range(100):
                _, loss_val = sess.run([train_op, loss])
                if step % 20 == 0:
                    print(f'Step {step}, Loss: {loss_val}')
        except tf.errors.OutOfRangeError:
            print('训练完成')
        finally:
            coord.request_stop()
            coord.join(threads)

关键点说明:

  1. TFRecord创建:将每个样本包装成tf.train.Example,序列化后写入文件。
  2. 数据解析:使用tf.parse_single_example解析TFRecord,恢复原始数据格式。
  3. QueueRunner流程
    • tf.train.string_input_producer创建文件名队列
    • tf.train.shuffle_batch创建样本队列
    • 必须启动tf.train.start_queue_runners才能让数据流动
  4. 协调器tf.train.Coordinator管理线程生命周期,确保程序正确退出。

注意:在TensorFlow 2.x中,推荐使用tf.data API替代QueueRunner,它提供了更简洁高效的数据管道。但理解QueueRunner有助于处理遗留代码和深入理解数据流机制。

一句话建议:对于新项目,优先使用tf.data API。

这种 feed data 的方法虽然性能最好,但是很不灵活,想要在 epoch 间切换到另一个 dataset 需要用 tf.where 之类的图内条件切换。这也是 TensorFlow 静态图的缺点之一。

TensorFlow 动态图机制 Eager Execution,10.31 号出的,你怎么看?

挺好的啊,给 TensorFlow 使用者提供了另一种选择。当然我觉得对于动态图有需求的早就转到 PyTorch 了。

我使用过,无非就是流化输入吧。有相应的 batch 接口啊,还可以 shuffle,可以参考源码里的 cifar10 的示例。

回到顶部