Python中关于深度学习的输出文件如何生成和处理?
想问一下,我用 resnet 网络训练的 cifar-10 数据集输出的模型里面没有 pbtxt 文件,只有 meta,index,data 和 checkpoint 文件。cifar-10 应该也是图像分类数据集吧,那要怎么得到它分类的 pbtxt 文件啊。
item {
id: 1
name: ‘Cat’
}
就是一个如上所示分类数据的文件。
Python中关于深度学习的输出文件如何生成和处理?
输出里没有 pbtxt 吧,这是分类名的映射,输入的时候填的,它的路径应该在你的 pipeline.config 文件里面有写
在Python深度学习中,生成和处理输出文件主要涉及模型训练结果、预测输出和中间数据的保存与读取。核心方法是使用标准库(如pickle、json)和深度学习框架(如PyTorch的torch.save、TensorFlow/Keras的model.save)进行序列化存储。
1. 保存模型权重和架构
# PyTorch示例
import torch
import torch.nn as nn
class SimpleNet(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(10, 2)
def forward(self, x):
return self.fc(x)
model = SimpleNet()
# 保存完整模型(架构+权重)
torch.save(model, 'model.pth')
# 仅保存权重
torch.save(model.state_dict(), 'weights.pth')
# 加载完整模型
loaded_model = torch.load('model.pth')
# 加载权重到现有架构
model.load_state_dict(torch.load('weights.pth'))
2. 保存训练指标和预测结果
import json
import numpy as np
import pandas as pd
# 保存训练历史
history = {'loss': [0.5, 0.3, 0.2], 'accuracy': [0.8, 0.9, 0.95]}
with open('history.json', 'w') as f:
json.dump(history, f)
# 保存预测结果
predictions = np.array([[0.1, 0.9], [0.8, 0.2]])
# 保存为numpy格式
np.save('predictions.npy', predictions)
# 保存为CSV
pd.DataFrame(predictions).to_csv('predictions.csv', index=False)
3. 处理图像/文本输出
from PIL import Image
import cv2
# 保存生成图像
def save_generated_images(images, path):
for i, img in enumerate(images):
# 假设img是0-1范围的numpy数组
img = (img * 255).astype(np.uint8)
Image.fromarray(img).save(f'{path}/image_{i}.png')
# 保存文本生成结果
def save_text_results(texts, filename):
with open(filename, 'w', encoding='utf-8') as f:
for i, text in enumerate(texts):
f.write(f"Result {i}: {text}\n")
4. 使用HDF5存储大型数据集
import h5py
# 保存特征和标签
with h5py.File('features.h5', 'w') as f:
f.create_dataset('features', data=np.random.randn(1000, 256))
f.create_dataset('labels', data=np.random.randint(0, 10, 1000))
# 读取
with h5py.File('features.h5', 'r') as f:
features = f['features'][:]
labels = f['labels'][:]
关键点总结:
- 模型保存:PyTorch用
.pth文件,TensorFlow用.h5或SavedModel格式 - 数据保存:结构化数据用CSV/JSON,张量数据用NumPy/HDF5
- 文件组织:建议按类型分目录(如
models/、results/、logs/)
根据你的具体任务选择合适的格式和工具。
我没有通过 object-detection 进行训练,没有 pipeline.config 这个文件。其实意思就是分类集合在数据集的某个文件里其实已经有了,因为训练分类的时候肯定会用到的,所以现在就是需要想办法找到并把它导出来吧。
你都不说你用的什么技术栈…… TensorFlow ?
TF 的话你用 code + checkpoint 就能恢复模型啊,什么 pbtxt 都是没必要的。
而且 pipeline.config 是什么鬼,用了这么久 TF 没见过这玩意。
是用的 tensorflow,我不是想恢复模型,我是想用训练出的模型进行图片的分类和检测,但是我看网上给出的方法都需要一个 pbtxt 的分类标签数据集,但是我训练出的模型只有 meta,index,ckpt 这些,也固化出了 frozen-graph-model.pb 文件,但是没找到只用这些文件就能检测识别的方法。
.pb 和 .pbtxt 是模型文件的两种表现方式,.pb 是在你使用 binary 方式保存的模型文件,.pbtxt 是在你保存模型的时候使用了 as_text = true 这一个控制条件才会生成。
如果仅有 .pb 文件怎么加载:
https://gist.github.com/Quorafind/b06d3d15b6636dc57e5216349635813c
那么怎么使用:
https://gist.github.com/Quorafind/a0d07b700b2fa2e91e487c074f45cc2d
参考:
https://www.tensorflow.org/guide/extend/model_files#text_or_binary
https://www.tensorflow.org/api_docs/python/tf/saved_model/Builder#save
https://stackoverflow.com/questions/51278213/what-is-the-use-of-modelpb-file-in-tensor-flow-and-how-does-it-works
https://blog.csdn.net/zryowen123/article/details/79889988
thanks 我去看一下
我说的恢复模型就是载入模型。。。。
你完全可以把你训练的时候建图用的代码再跑一遍,然后通过 whatever 方法把 variable restore 到 session 里面,然后你想干啥都行了。
顺便说一句,TF 2.0 和 PyTorch 都是动态图了,而动态图用代码 + restore variable 是最自然的方法。。。
每个人的知识边界不同,没见过很正常。我的知识可能没你全面,但我课上用过的那个确实就是 pipeline.config 文件
https://github.com/tensorflow/models/blob/master/research/object_detection/samples/configs/faster_rcnn_resnet101_pets.config
方便留个联系方式么?
你好再问一个问题,刚才用 object-detection 试了一下,在运行 model_main.py 主程序的时候,出现了下面的问题:
WARNING:tensorflow:Forced number of epochs for all eval validations to be 1.
WARNING:tensorflow:Expected number of evaluation epochs is 1, but instead encountered eval_on_train_input_config.num_epochs = 0. Overwriting num_epochs to 1.
WARNING:tensorflow:Estimator’s model_fn (<function create_model_fn.<locals>.model_fn at 0x000001B0B6282F28>) includes params argument, but params are not passed to Estimator.
出现了这三个警告后程序就自己停止了,我看了一下对应的 config 文件里面也没有 num_epochs 这个参数,我就自己在最后加了一个 num_epochs: 1,运行了还是这个问题,想请问一下这种情况是怎么回事。
Good。事实上我没见过这种 .config
印象中没遇到过这个这种问题,我搜了一下下面这两个 issue 好像跟你遇到的问题差不多?我不太了解具体情况,可能还是得需要你自己搜一下这三个 warning 代表啥
https://github.com/tensorflow/models/issues/5790
https://github.com/kubeflow/examples/issues/277

