Rust机器学习推理引擎tract-core的使用,高性能ONNX/TensorFlow模型解析与执行
Rust机器学习推理引擎tract-core的使用,高性能ONNX/TensorFlow模型解析与执行
安装
在项目目录中运行以下Cargo命令:
cargo add tract-core
或者在Cargo.toml中添加以下行:
tract-core = "0.21.13"
基本使用示例
下面是一个使用tract-core加载和运行ONNX模型的完整示例:
use tract_onnx::prelude::*;
fn main() -> TractResult<()> {
// 1. 加载ONNX模型
let model = tract_onnx::onnx()
// 从文件加载模型
.model_for_path("path/to/model.onnx")?
// 优化模型
.into_optimized()?
// 转换为可执行模型
.into_runnable()?;
// 2. 准备输入数据
// 假设模型需要一个形状为[1, 3, 224, 224]的f32张量
let input = Tensor::from_shape(
&[1, 3, 224, 224],
&vec![0.0f32; 1 * 3 * 224 * 224],
)?;
// 3. 运行推理
let outputs = model.run(tvec!(input))?;
// 4. 处理输出
let output = outputs[0].to_array_view::<f32>()?;
println!("模型输出: {:?}", output);
Ok(())
}
加载TensorFlow模型示例
use tract_tensorflow::prelude::*;
fn main() -> TractResult<()> {
// 1. 加载TensorFlow模型
let model = tract_tensorflow::tensorflow()
// 从文件加载模型
.model_for_path("path/to/model.pb")?
// 指定输入形状和类型
.with_input_fact(0, InferenceFact::dt_shape(f32::datum_type(), tvec!(1, 224, 224, 3)))?
// 优化模型
.into_optimized()?
// 转换为可执行模型
.into_runnable()?;
// 2. 准备输入数据
let input = Tensor::from_shape(
&[1, 224, 224, 3],
&vec![0.0f32; 1 * 224 * 224 * 3],
)?;
// 3. 运行推理
let outputs = model.run(tvec!(input))?;
// 4. 处理输出
let output = outputs[0].to_array_view::<f32>()?;
println!("模型输出: {:?}", output);
Ok(())
}
高级功能
动态输入形状
use tract_onnx::prelude::*;
fn main() -> TractResult<()> {
let mut model = tract_onnx::onnx()
.model_for_path("path/to/model.onnx")?
// 声明动态输入形状
.with_input_fact(0, InferenceFact::dt_shape(f32::datum_type(), tvec!(1, 3, Symbol::from('H'), Symbol::from('W')))?
.into_optimized()?;
// 指定具体输入尺寸
model.set_input_fact(0, InferenceFact::dt_shape(f32::datum_type(), tvec!(1, 3, 224, 224))?;
let runnable = model.into_runnable()?;
// ...运行推理
Ok(())
}
性能优化
use tract_onnx::prelude::*;
fn main() -> TractResult<()> {
let model = tract_onnx::onnx()
.model_for_path("path/to/model.onnx")?
// 启用更多优化
.with_optimization_policy(tract_core::OptimizationPolicy::Max)
// 指定目标架构以获得最佳性能
.with_target(tract_core::Target::Host)
.into_optimized()?
.into_runnable()?;
// ...运行推理
Ok(())
}
完整示例
下面是一个完整的ONNX模型推理示例,包含错误处理和更详细的注释:
use tract_onnx::prelude::*;
fn main() {
match run_inference() {
Ok(_) => println!("推理成功完成"),
Err(e) => println!("推理过程中发生错误: {}", e),
}
}
fn run_inference() -> TractResult<()> {
// 1. 加载并准备模型
let model = tract_onnx::onnx()
.model_for_path("path/to/model.onnx")?
.with_input_fact(0, InferenceFact::dt_shape(f32::datum_type(), tvec!(1, 3, 224, 224)))?
.into_optimized()?
.into_runnable()?;
// 2. 创建随机输入数据 (实际应用中应使用真实数据)
let input = Tensor::rand::<f32>(&[1, 3, 224, 224]);
// 3. 运行推理并测量时间
use std::time::Instant;
let start = Instant::now();
let outputs = model.run(tvec!(input))?;
let duration = start.elapsed();
println!("推理耗时: {:?}", duration);
// 4. 处理输出
let output = outputs[0].to_array_view::<f32>()?;
println!("输出形状: {:?}", output.shape());
// 获取预测结果 (假设是分类模型)
if let Some((max_idx, max_val)) = output.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
{
println!("预测类别: {}, 置信度: {}", max_idx, max_val);
}
Ok(())
}
1 回复
Rust机器学习推理引擎tract-core的使用指南
什么是tract-core?
tract-core是一个纯Rust实现的轻量级机器学习推理引擎,专注于高性能的模型解析与执行。它支持ONNX和TensorFlow模型格式,能够在Rust生态系统中提供高效的神经网络推理能力。
主要特性
- 支持ONNX和TensorFlow模型导入
- 无运行时依赖的纯Rust实现
- 高性能的模型执行
- 支持CPU上的多线程执行
- 提供简单易用的API
安装方法
在Cargo.toml中添加依赖:
[dependencies]
tract-core = "0.18"
tract-onnx = "0.18" # 如果需要ONNX支持
tract-tensorflow = "0.18" # 如果需要TensorFlow支持
基本使用方法
加载并运行ONNX模型
use tract_onnx::prelude::*;
fn main() -> TractResult<()> {
// 加载ONNX模型
let model = tract_onnx::onnx()
.model_for_path("path/to/model.onnx")?
.into_optimized()?
.into_runnable()?;
// 准备输入数据 (示例使用随机数据)
let input = tensor1(&[0f32, 1.0, 2.0, 3.0, 4.0]).into_shape(&[1, 5])?;
// 运行模型
let outputs = model.run(tvec!(input))?;
// 处理输出
println!("模型输出: {:?}", outputs[0]);
Ok(())
}
加载并运行TensorFlow模型
use tract_tensorflow::prelude::*;
fn main() -> TractResult<()> {
// 加载TensorFlow模型
let model = tract_tensorflow::tensorflow()
.model_for_path("path/to/model.pb")?
.into_optimized()?
.into_runnable()?;
// 准备输入数据
let input = tensor1(&[1f32, 2.0, 3.0]).into_shape(&[1, 3])?;
// 运行模型
let outputs = model.run(tvec!(input))?;
// 处理输出
println!("模型输出: {:?}", outputs[0]);
Ok(())
}
高级用法
指定输入输出节点名称
let model = tract_onnx::onnx()
.model_for_path("model.onnx")?
.with_input_names(vec!["input_name"])?
.with_output_names(vec!["output_name"])?
.into_optimized()?
.into_runnable()?;
使用多线程执行
let mut model = tract_onnx::onnx()
.model_for_path("model.onnx")?
.into_optimized()?
.into_runnable()?;
// 设置线程数
model.set_concurrency(4);
性能优化建议
- 使用
into_optimized()
对模型进行优化 - 对于重复执行,复用
RunnableModel
实例 - 合理设置并发线程数
- 考虑使用定点的量化模型减少计算量
常见问题解决
- 模型加载失败:检查模型路径是否正确,确保模型格式与使用的加载器匹配
- 输入形状不匹配:使用
input.into_shape()
调整输入张量的形状 - 性能问题:尝试使用
into_optimized()
优化模型,并调整并发设置
实际应用示例:图像分类
use tract_onnx::prelude::*;
use image::GenericImageView;
fn classify_image(model_path: &str, image_path: &str) -> TractResult<()> {
// 加载模型
let model = tract_onnx::onnx()
.model_for_path(model_path)?
.into_optimized()?
.into_runnable()?;
// 加载并预处理图像
let img = image::open(image_path)?;
let resized = img.resize_exact(224, 224, image::imageops::FilterType::Triangle);
let image = tract_ndarray::Array4::from_shape_fn((1, 3, 224, 224), |(_, c, y, x)| {
let pixel = resized.get_pixel(x as u32, y as u32);
let mean = [0.485, 0.456, 0.406][c];
let std = [0.229, 0.224, 0.225][c];
(pixel[c] as f32 / 255.0 - mean) / std
}).into();
// 运行推理
let outputs = model.run(tvec!(image))?;
let best = outputs[0]
.to_array_view::<f32>()?
.iter()
.enumerate()
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap());
println!("预测结果: 类别 {} 置信度 {}", best.unwrap().0, best.unwrap().1);
Ok(())
}
完整示例代码
下面是一个完整的图像分类示例,包含模型加载、图像预处理和推理全过程:
use tract_onnx::prelude::*;
use image::{GenericImageView, ImageBuffer, Rgb};
use std::path::Path;
fn main() -> TractResult<()> {
// 1. 加载ONNX模型
let model_path = "resnet18.onnx";
let image_path = "cat.jpg";
// 2. 加载并预处理图像
let img = image::open(image_path)?;
let resized = img.resize_exact(224, 224, image::imageops::FilterType::Triangle);
// 3. 转换为模型需要的输入格式 (NCHW)
let image = tract_ndarray::Array4::from_shape_fn((1, 3, 224, 224), |(_, c, y, x)| {
let pixel = resized.get_pixel(x as u32, y as u32);
// ImageNet数据集的标准归一化参数
let mean = [0.485, 0.456, 0.406][c];
let std = [0.229, 0.224, 0.225][c];
(pixel[c] as f32 / 255.0 - mean) / std
}).into();
// 4. 加载并优化模型
let model = tract_onnx::onnx()
.model_for_path(model_path)?
.into_optimized()?
.into_runnable()?;
// 5. 运行推理
let outputs = model.run(tvec!(image))?;
// 6. 处理输出结果
let output = outputs[0].to_array_view::<f32>()?;
let mut probabilities: Vec<(usize, f32)> = output.iter().enumerate().map(|(i, &v)| (i, v)).collect();
probabilities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
// 打印前5个最可能的类别
println!("Top 5 predictions:");
for (i, (class, prob)) in probabilities.iter().take(5).enumerate() {
println!("{}. 类别 {}: {:.2}%", i+1, class, prob * 100.0);
}
Ok(())
}
要运行此示例,您需要:
- 准备一个ONNX格式的模型文件(如resnet18.onnx)
- 准备一张测试图像(如cat.jpg)
- 在Cargo.toml中添加依赖:
[dependencies]
tract-onnx = "0.18"
image = "0.24"
tract-core为Rust开发者提供了一个高效、安全的机器学习推理解决方案,特别适合需要将模型集成到Rust应用程序中的场景。