Python中TensorFlow的TFRecord和QueueRunner简介与使用
如何将数据集转换为 TensorFlow 的 TFRecord 格式呢?为什么要转换为 TFRecord 格式?如何使用文件队列?如何生成 Batches ?
Python中TensorFlow的TFRecord和QueueRunner简介与使用
原来那个 tf.contrib.data 在 1.4 并入 tf 变成 tf.data 的数据预处理模块楼主用过吗?好多地方都推荐用这个
TFRecord是TensorFlow的一种二进制数据格式,用于高效存储和读取大规模数据集。它把数据序列化成tf.train.Example Protocol Buffers格式,能有效减少I/O开销,特别适合处理图像、文本等非结构化数据。
QueueRunner是TensorFlow早期版本(1.x)中用于异步数据读取和预取的核心机制。它管理着一组队列(如tf.RandomShuffleQueue),并通过启动多个线程来执行入队操作,确保训练时数据能持续供给,避免GPU等待数据。
下面是一个完整的示例,展示如何创建TFRecord文件,并使用QueueRunner配合tf.train.string_input_producer进行读取和训练:
import tensorflow as tf
import numpy as np
# 1. 创建示例TFRecord文件
def write_tfrecord(filename):
writer = tf.python_io.TFRecordWriter(filename)
for i in range(10):
# 创建示例数据
feature = {
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[i])),
'image': tf.train.Feature(bytes_list=tf.train.BytesList(
value=[np.random.randn(28, 28, 3).astype(np.float32).tobytes()]))
}
example = tf.train.Example(features=tf.train.Features(feature=feature))
writer.write(example.SerializeToString())
writer.close()
# 2. 读取和解析TFRecord
def read_and_decode(filename_queue):
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
serialized_example,
features={
'label': tf.FixedLenFeature([], tf.int64),
'image': tf.FixedLenFeature([], tf.string)
})
image = tf.decode_raw(features['image'], tf.float32)
image = tf.reshape(image, [28, 28, 3])
label = tf.cast(features['label'], tf.int32)
return image, label
# 3. 使用QueueRunner构建输入管道
def input_pipeline(filenames, batch_size, num_epochs=None):
filename_queue = tf.train.string_input_producer(
filenames, num_epochs=num_epochs, shuffle=True)
image, label = read_and_decode(filename_queue)
# 使用tf.train.shuffle_batch进行批处理
min_after_dequeue = 1000
capacity = min_after_dequeue + 3 * batch_size
image_batch, label_batch = tf.train.shuffle_batch(
[image, label], batch_size=batch_size, capacity=capacity,
min_after_dequeue=min_after_dequeue)
return image_batch, label_batch
# 主程序
if __name__ == "__main__":
# 写入示例文件
write_tfrecord('data.tfrecord')
# 构建计算图
image_batch, label_batch = input_pipeline(['data.tfrecord'], batch_size=4)
# 定义简单模型
logits = tf.layers.dense(tf.layers.flatten(image_batch), units=10)
loss = tf.losses.sparse_softmax_cross_entropy(labels=label_batch, logits=logits)
train_op = tf.train.AdamOptimizer(0.001).minimize(loss)
# 训练
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer()) # 用于string_input_producer的epoch计数
# 启动QueueRunner
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
try:
for step in range(100):
_, loss_val = sess.run([train_op, loss])
if step % 20 == 0:
print(f'Step {step}, Loss: {loss_val}')
except tf.errors.OutOfRangeError:
print('训练完成')
finally:
coord.request_stop()
coord.join(threads)
关键点说明:
- TFRecord创建:将每个样本包装成
tf.train.Example,序列化后写入文件。 - 数据解析:使用
tf.parse_single_example解析TFRecord,恢复原始数据格式。 - QueueRunner流程:
tf.train.string_input_producer创建文件名队列tf.train.shuffle_batch创建样本队列- 必须启动
tf.train.start_queue_runners才能让数据流动
- 协调器:
tf.train.Coordinator管理线程生命周期,确保程序正确退出。
注意:在TensorFlow 2.x中,推荐使用tf.data API替代QueueRunner,它提供了更简洁高效的数据管道。但理解QueueRunner有助于处理遗留代码和深入理解数据流机制。
一句话建议:对于新项目,优先使用tf.data API。
这种 feed data 的方法虽然性能最好,但是很不灵活,想要在 epoch 间切换到另一个 dataset 需要用 tf.where 之类的图内条件切换。这也是 TensorFlow 静态图的缺点之一。
TensorFlow 动态图机制 Eager Execution,10.31 号出的,你怎么看?
挺好的啊,给 TensorFlow 使用者提供了另一种选择。当然我觉得对于动态图有需求的早就转到 PyTorch 了。
我使用过,无非就是流化输入吧。有相应的 batch 接口啊,还可以 shuffle,可以参考源码里的 cifar10 的示例。

