Rust AWS SDK库aws-sdk-sagemakerruntime的使用:实现与Amazon SageMaker Runtime服务的无缝集成

Rust AWS SDK库aws-sdk-sagemakerruntime的使用:实现与Amazon SageMaker Runtime服务的无缝集成

Amazon SageMaker Runtime API的Rust SDK实现。

开始使用

SDK为每个AWS服务提供一个crate。您必须在Rust项目中添加Tokio作为依赖项以执行异步代码。要将aws-sdk-sagemakerruntime添加到您的项目中,请在Cargo.toml文件中添加以下内容:

[dependencies]
aws-config = { version = "1.1.7", features = ["behavior-version-latest"] }
aws-sdk-sagemakerruntime = "1.83.0"
tokio = { version = "1", features = ["full"] }

然后在代码中,可以按以下方式创建客户端:

use aws_sdk_sagemakerruntime as sagemakerruntime;

#[::tokio::main]
async fn main() -> Result<(), sagemakerruntime::Error> {
    let config = aws_config::load_from_env().await;
    let client = aws_sdk_sagemakerruntime::Client::new(&config);

    // ... 使用客户端进行调用

    Ok(())
}

完整示例

以下是一个完整的示例,展示如何使用aws-sdk-sagemakerruntime与Amazon SageMaker Runtime服务交互:

use aws_sdk_sagemakerruntime as sagemakerruntime;
use aws_sdk_sagemakerruntime::types::Blob;
use std::error::Error;

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
    // 加载AWS配置(从环境变量/配置文件等)
    let config = aws_config::load_from_env().await;
    
    // 创建SageMaker Runtime客户端
    let client = sagemakerruntime::Client::new(&config);

    // 准备输入数据(这里使用示例JSON数据)
    let input_data = r#"{"instances": [{"data": [1.0, 2.0, 3.0, 4.0]}]}"#;
    let blob = Blob::new(input_data.as_bytes().to_vec());

    // 调用SageMaker端点进行预测
    let endpoint_name = "your-endpoint-name"; // 替换为您的SageMaker端点名称
    let content_type = "application/json";   // 根据您的模型调整内容类型
    
    let resp = client.invoke_endpoint()
        .endpoint_name(endpoint_name)
        .content_type(content_type)
        .body(blob)
        .send()
        .await?;

    // 处理响应
    if let Some(output) = resp.body() {
        println!("预测结果: {:?}", String::from_utf8_lossy(output.as_ref()));
    }

    Ok(())
}

使用说明

  1. 确保您已正确配置AWS凭证(通过环境变量或~/.aws/credentials文件)
  2. 替换示例中的your-endpoint-name为您实际的SageMaker端点名称
  3. 根据您的模型调整输入数据的格式和内容类型

许可证

该项目基于Apache-2.0许可证。


1 回复

Rust AWS SDK库aws-sdk-sagemakerruntime的使用指南

介绍

aws-sdk-sagemakerruntime是Rust语言的AWS SDK的一部分,它允许开发者与Amazon SageMaker Runtime服务进行交互。这个库使Rust应用程序能够轻松调用部署在SageMaker上的机器学习模型端点,发送推理请求并接收预测结果。

Amazon SageMaker Runtime服务提供了低延迟、高吞吐量的API,用于与部署在SageMaker上的模型进行交互。使用这个Rust SDK,你可以将机器学习预测功能集成到你的Rust应用程序中。

安装

在Cargo.toml中添加依赖:

[dependencies]
aws-config = "0.55"
aws-sdk-sagemakerruntime = "0.25"
tokio = { version = "1", features = ["full"] }

基本使用方法

1. 创建客户端

首先需要创建一个SageMaker Runtime客户端:

use aws_sdk_sagemakerruntime::Client;

#[tokio::main]
async fn main() -> Result<(), aws_sdk_sagemakerruntime::Error> {
    let config = aws_config::load_from_env().await;
    let client = Client::new(&config);
    
    // 后续操作...
    Ok(())
}

2. 调用模型端点进行预测

use aws_sdk_sagemakerruntime::types::Blob;

async fn invoke_endpoint(
    client: &Client,
    endpoint_name: &str,
    input_data: Vec<u8>,
) -> Result<(), aws_sdk_sagemakerruntime::Error> {
    let response = client
        .invoke_endpoint()
        .endpoint_name(endpoint_name)
        .content_type("application/json")  // 根据你的模型调整
        .body(Blob::new(input_data))
        .send()
        .await?;

    // 处理响应
    let prediction = response.body.collect().await?;
    println!("Prediction result: {:?}", prediction.into_bytes());
    
    Ok(())
}

完整示例

下面是一个完整的示例,展示如何调用SageMaker端点进行预测:

use aws_sdk_sagemakerruntime::{Client, types::Blob};
use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), aws_sdk_sagemakerruntime::Error> {
    // 加载AWS配置
    let config = aws_config::load_from_env().await;
    let client = Client::new(&config);
    
    // 准备输入数据
    let input_data = json!({
        "features": [5.1, 3.5, 1.4, 0.2]
    });
    let input_bytes = serde_json::to_vec(&input_data).unwrap();
    
    // 调用端点
    let endpoint_name = "my-sagemaker-endpoint";
    let response = client
        .invoke_endpoint()
        .endpoint_name(endpoint_name)
        .content_type("application/json")
        .body(Blob::new(input_bytes))
        .send()
        .await?;
    
    // 处理响应
    let prediction = response.body.collect().await?;
    let prediction_bytes = prediction.into_bytes();
    
    // 假设响应是JSON格式
    let prediction_result: serde_json::Value = serde_json::from_slice(&prediction_bytes).unwrap();
    println!("Prediction result: {}", prediction_result);
    
    Ok(())
}

高级功能

1. 异步流式处理

对于大输入或输出,可以使用流式处理:

use bytes::Bytes;
use futures::stream;
use aws_sdk_sagemakerruntime::primitives::ByteStream;

async fn invoke_endpoint_streaming(
    client: &Client,
    endpoint_name: &str,
) -> Result<(), aws_sdk_sagemakerruntime::Error> {
    // 创建流式输入
    let data_stream = stream::iter(vec![
        Ok(Bytes::from("part1")),
        Ok(Bytes::from("part2")),
    ]);
    
    let response = client
        .invoke_endpoint()
        .endpoint_name(endpoint_name)
        .content_type("application/octet-stream")
        .body(ByteStream::new(data_stream))
        .send()
        .await?;
    
    // 流式处理响应
    let mut response_stream = response.body.into_async_read();
    // ... 处理流数据
    
    Ok(())
}

2. 错误处理

async fn safe_invoke(
    client: &Client,
    endpoint_name: &str,
    input: Vec<u8>,
) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
    let response = client
        .invoke_endpoint()
        .endpoint_name(endpoint_name)
        .body(Blob::new(input))
        .send()
        .await
        .map_err(|e| {
            eprintln!("Failed to invoke endpoint: {}", e);
            e
        })?;
    
    let prediction = response.body
        .collect()
        .await
        .map_err(|e| {
            eprintln!("Failed to read response: {}", e);
            e
        })?;
    
    Ok(prediction.into_bytes())
}

最佳实践

  1. 重用客户端:AWS客户端设计为可重用且线程安全,应该尽可能重用而不是每次请求都创建新的

  2. 设置合理的超时:可以通过AWS配置设置适当的超时时间

  3. 处理限流:实现重试逻辑处理SageMaker的限流响应

  4. 内容类型匹配:确保content_type与模型期望的输入格式匹配

  5. 监控和日志:记录请求和响应大小、延迟等指标

总结

aws-sdk-sagemakerruntime库为Rust开发者提供了与Amazon SageMaker Runtime服务交互的高效方式。通过这个库,你可以轻松地将部署在SageMaker上的机器学习模型集成到Rust应用程序中,实现低延迟的预测功能。

完整示例Demo

下面是一个增强版的完整示例,包含错误处理、日志记录和配置设置:

use aws_sdk_sagemakerruntime::{Client, types::Blob};
use serde_json::json;
use tracing::info;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // 初始化日志
    tracing_subscriber::fmt::init();
    
    // 加载AWS配置(可自定义超时等设置)
    let config = aws_config::from_env()
        .timeout_config(
            aws_config::timeout::TimeoutConfig::builder()
                .operation_timeout(std::time::Duration::from_secs(30))
                .build()
        )
        .load()
        .await;
    
    // 创建客户端
    let client = Client::new(&config);
    
    // 准备输入数据
    let input_data = json!({
        "features": [5.1, 3.5, 1.4, 0.2],
        "additional_params": {
            "mode": "predict"
        }
    });
    let input_bytes = serde_json::to_vec(&input_data)?;
    info!("Prepared input data with size: {} bytes", input_bytes.len());
    
    // 调用端点
    let endpoint_name = "my-sagemaker-endpoint";
    info!("Invoking endpoint: {}", endpoint_name);
    
    let response = match client
        .invoke_endpoint()
        .endpoint_name(endpoint_name)
        .content_type("application/json")
        .body(Blob::new(input_bytes))
        .send()
        .await 
    {
        Ok(resp) => resp,
        Err(e) => {
            tracing::error!("Failed to invoke endpoint: {}", e);
            return Err(e.into());
        }
    };
    
    // 处理响应
    let prediction = response.body
        .collect()
        .await
        .map_err(|e| {
            tracing::error!("Failed to read response: {}", e);
            e
        })?;
    
    let prediction_bytes = prediction.into_bytes();
    info!("Received prediction with size: {} bytes", prediction_bytes.len());
    
    // 解析JSON响应
    let prediction_result: serde_json::Value = serde_json::from_slice(&prediction_bytes)?;
    println!("Final prediction result: {}", prediction_result);
    
    Ok(())
}

这个增强版示例包含以下改进:

  1. 添加了日志记录功能,使用tracing库记录关键操作
  2. 配置了自定义超时设置
  3. 提供了更详细的错误处理
  4. 包含了输入/输出数据大小的日志记录
  5. 使用了更结构化的输入数据格式

要运行此示例,需要在Cargo.toml中添加额外的依赖:

[dependencies]
tracing = "0.1"
tracing-subscriber = "0.3"
回到顶部