Rust机器学习推理引擎tract-core的使用,高性能ONNX/TensorFlow模型解析与执行

Rust机器学习推理引擎tract-core的使用,高性能ONNX/TensorFlow模型解析与执行

安装

在项目目录中运行以下Cargo命令:

cargo add tract-core

或者在Cargo.toml中添加以下行:

tract-core = "0.21.13"

基本使用示例

下面是一个使用tract-core加载和运行ONNX模型的完整示例:

use tract_onnx::prelude::*;

fn main() -> TractResult<()> {
    // 1. 加载ONNX模型
    let model = tract_onnx::onnx()
        // 从文件加载模型
        .model_for_path("path/to/model.onnx")?
        // 优化模型
        .into_optimized()?
        // 转换为可执行模型
        .into_runnable()?;

    // 2. 准备输入数据
    // 假设模型需要一个形状为[1, 3, 224, 224]的f32张量
    let input = Tensor::from_shape(
        &[1, 3, 224, 224],
        &vec![0.0f32; 1 * 3 * 224 * 224],
    )?;

    // 3. 运行推理
    let outputs = model.run(tvec!(input))?;

    // 4. 处理输出
    let output = outputs[0].to_array_view::<f32>()?;
    println!("模型输出: {:?}", output);

    Ok(())
}

加载TensorFlow模型示例

use tract_tensorflow::prelude::*;

fn main() -> TractResult<()> {
    // 1. 加载TensorFlow模型
    let model = tract_tensorflow::tensorflow()
        // 从文件加载模型
        .model_for_path("path/to/model.pb")?
        // 指定输入形状和类型
        .with_input_fact(0, InferenceFact::dt_shape(f32::datum_type(), tvec!(1, 224, 224, 3)))?
        // 优化模型
        .into_optimized()?
        // 转换为可执行模型
        .into_runnable()?;

    // 2. 准备输入数据
    let input = Tensor::from_shape(
        &[1, 224, 224, 3],
        &vec![0.0f32; 1 * 224 * 224 * 3],
    )?;

    // 3. 运行推理
    let outputs = model.run(tvec!(input))?;

    // 4. 处理输出
    let output = outputs[0].to_array_view::<f32>()?;
    println!("模型输出: {:?}", output);

    Ok(())
}

高级功能

动态输入形状

use tract_onnx::prelude::*;

fn main() -> TractResult<()> {
    let mut model = tract_onnx::onnx()
        .model_for_path("path/to/model.onnx")?
        // 声明动态输入形状
        .with_input_fact(0, InferenceFact::dt_shape(f32::datum_type(), tvec!(1, 3, Symbol::from('H'), Symbol::from('W')))?
        .into_optimized()?;

    // 指定具体输入尺寸
    model.set_input_fact(0, InferenceFact::dt_shape(f32::datum_type(), tvec!(1, 3, 224, 224))?;
    
    let runnable = model.into_runnable()?;
    // ...运行推理
    Ok(())
}

性能优化

use tract_onnx::prelude::*;

fn main() -> TractResult<()> {
    let model = tract_onnx::onnx()
        .model_for_path("path/to/model.onnx")?
        // 启用更多优化
        .with_optimization_policy(tract_core::OptimizationPolicy::Max)
        // 指定目标架构以获得最佳性能
        .with_target(tract_core::Target::Host)
        .into_optimized()?
        .into_runnable()?;
    
    // ...运行推理
    Ok(())
}

完整示例

下面是一个完整的ONNX模型推理示例,包含错误处理和更详细的注释:

use tract_onnx::prelude::*;

fn main() {
    match run_inference() {
        Ok(_) => println!("推理成功完成"),
        Err(e) => println!("推理过程中发生错误: {}", e),
    }
}

fn run_inference() -> TractResult<()> {
    // 1. 加载并准备模型
    let model = tract_onnx::onnx()
        .model_for_path("path/to/model.onnx")?
        .with_input_fact(0, InferenceFact::dt_shape(f32::datum_type(), tvec!(1, 3, 224, 224)))?
        .into_optimized()?
        .into_runnable()?;

    // 2. 创建随机输入数据 (实际应用中应使用真实数据)
    let input = Tensor::rand::<f32>(&[1, 3, 224, 224]);

    // 3. 运行推理并测量时间
    use std::time::Instant;
    let start = Instant::now();
    let outputs = model.run(tvec!(input))?;
    let duration = start.elapsed();
    println!("推理耗时: {:?}", duration);

    // 4. 处理输出
    let output = outputs[0].to_array_view::<f32>()?;
    println!("输出形状: {:?}", output.shape());
    
    // 获取预测结果 (假设是分类模型)
    if let Some((max_idx, max_val)) = output.iter()
        .enumerate()
        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
    {
        println!("预测类别: {}, 置信度: {}", max_idx, max_val);
    }

    Ok(())
}

1 回复

Rust机器学习推理引擎tract-core的使用指南

什么是tract-core?

tract-core是一个纯Rust实现的轻量级机器学习推理引擎,专注于高性能的模型解析与执行。它支持ONNX和TensorFlow模型格式,能够在Rust生态系统中提供高效的神经网络推理能力。

主要特性

  • 支持ONNX和TensorFlow模型导入
  • 无运行时依赖的纯Rust实现
  • 高性能的模型执行
  • 支持CPU上的多线程执行
  • 提供简单易用的API

安装方法

在Cargo.toml中添加依赖:

[dependencies]
tract-core = "0.18"
tract-onnx = "0.18"  # 如果需要ONNX支持
tract-tensorflow = "0.18"  # 如果需要TensorFlow支持

基本使用方法

加载并运行ONNX模型

use tract_onnx::prelude::*;

fn main() -> TractResult<()> {
    // 加载ONNX模型
    let model = tract_onnx::onnx()
        .model_for_path("path/to/model.onnx")?
        .into_optimized()?
        .into_runnable()?;

    // 准备输入数据 (示例使用随机数据)
    let input = tensor1(&[0f32, 1.0, 2.0, 3.0, 4.0]).into_shape(&[1, 5])?;

    // 运行模型
    let outputs = model.run(tvec!(input))?;

    // 处理输出
    println!("模型输出: {:?}", outputs[0]);
    Ok(())
}

加载并运行TensorFlow模型

use tract_tensorflow::prelude::*;

fn main() -> TractResult<()> {
    // 加载TensorFlow模型
    let model = tract_tensorflow::tensorflow()
        .model_for_path("path/to/model.pb")?
        .into_optimized()?
        .into_runnable()?;

    // 准备输入数据
    let input = tensor1(&[1f32, 2.0, 3.0]).into_shape(&[1, 3])?;

    // 运行模型
    let outputs = model.run(tvec!(input))?;

    // 处理输出
    println!("模型输出: {:?}", outputs[0]);
    Ok(())
}

高级用法

指定输入输出节点名称

let model = tract_onnx::onnx()
    .model_for_path("model.onnx")?
    .with_input_names(vec!["input_name"])?
    .with_output_names(vec!["output_name"])?
    .into_optimized()?
    .into_runnable()?;

使用多线程执行

let mut model = tract_onnx::onnx()
    .model_for_path("model.onnx")?
    .into_optimized()?
    .into_runnable()?;

// 设置线程数
model.set_concurrency(4);

性能优化建议

  1. 使用into_optimized()对模型进行优化
  2. 对于重复执行,复用RunnableModel实例
  3. 合理设置并发线程数
  4. 考虑使用定点的量化模型减少计算量

常见问题解决

  1. 模型加载失败:检查模型路径是否正确,确保模型格式与使用的加载器匹配
  2. 输入形状不匹配:使用input.into_shape()调整输入张量的形状
  3. 性能问题:尝试使用into_optimized()优化模型,并调整并发设置

实际应用示例:图像分类

use tract_onnx::prelude::*;
use image::GenericImageView;

fn classify_image(model_path: &str, image_path: &str) -> TractResult<()> {
    // 加载模型
    let model = tract_onnx::onnx()
        .model_for_path(model_path)?
        .into_optimized()?
        .into_runnable()?;

    // 加载并预处理图像
    let img = image::open(image_path)?;
    let resized = img.resize_exact(224, 224, image::imageops::FilterType::Triangle);
    let image = tract_ndarray::Array4::from_shape_fn((1, 3, 224, 224), |(_, c, y, x)| {
        let pixel = resized.get_pixel(x as u32, y as u32);
        let mean = [0.485, 0.456, 0.406][c];
        let std = [0.229, 0.224, 0.225][c];
        (pixel[c] as f32 / 255.0 - mean) / std
    }).into();

    // 运行推理
    let outputs = model.run(tvec!(image))?;
    let best = outputs[0]
        .to_array_view::<f32>()?
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.partial_cmp(b.1).unwrap());

    println!("预测结果: 类别 {} 置信度 {}", best.unwrap().0, best.unwrap().1);
    Ok(())
}

完整示例代码

下面是一个完整的图像分类示例,包含模型加载、图像预处理和推理全过程:

use tract_onnx::prelude::*;
use image::{GenericImageView, ImageBuffer, Rgb};
use std::path::Path;

fn main() -> TractResult<()> {
    // 1. 加载ONNX模型
    let model_path = "resnet18.onnx";
    let image_path = "cat.jpg";
    
    // 2. 加载并预处理图像
    let img = image::open(image_path)?;
    let resized = img.resize_exact(224, 224, image::imageops::FilterType::Triangle);
    
    // 3. 转换为模型需要的输入格式 (NCHW)
    let image = tract_ndarray::Array4::from_shape_fn((1, 3, 224, 224), |(_, c, y, x)| {
        let pixel = resized.get_pixel(x as u32, y as u32);
        // ImageNet数据集的标准归一化参数
        let mean = [0.485, 0.456, 0.406][c];
        let std = [0.229, 0.224, 0.225][c];
        (pixel[c] as f32 / 255.0 - mean) / std
    }).into();

    // 4. 加载并优化模型
    let model = tract_onnx::onnx()
        .model_for_path(model_path)?
        .into_optimized()?
        .into_runnable()?;

    // 5. 运行推理
    let outputs = model.run(tvec!(image))?;
    
    // 6. 处理输出结果
    let output = outputs[0].to_array_view::<f32>()?;
    let mut probabilities: Vec<(usize, f32)> = output.iter().enumerate().map(|(i, &v)| (i, v)).collect();
    probabilities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
    
    // 打印前5个最可能的类别
    println!("Top 5 predictions:");
    for (i, (class, prob)) in probabilities.iter().take(5).enumerate() {
        println!("{}. 类别 {}: {:.2}%", i+1, class, prob * 100.0);
    }

    Ok(())
}

要运行此示例,您需要:

  1. 准备一个ONNX格式的模型文件(如resnet18.onnx)
  2. 准备一张测试图像(如cat.jpg)
  3. 在Cargo.toml中添加依赖:
[dependencies]
tract-onnx = "0.18"
image = "0.24"

tract-core为Rust开发者提供了一个高效、安全的机器学习推理解决方案,特别适合需要将模型集成到Rust应用程序中的场景。

回到顶部