Rust AWS SDK库aws-sdk-sagemakerruntime的使用:实现与Amazon SageMaker Runtime服务的无缝集成
Rust AWS SDK库aws-sdk-sagemakerruntime的使用:实现与Amazon SageMaker Runtime服务的无缝集成
Amazon SageMaker Runtime API的Rust SDK实现。
开始使用
SDK为每个AWS服务提供一个crate。您必须在Rust项目中添加Tokio作为依赖项以执行异步代码。要将aws-sdk-sagemakerruntime
添加到您的项目中,请在Cargo.toml文件中添加以下内容:
[dependencies]
aws-config = { version = "1.1.7", features = ["behavior-version-latest"] }
aws-sdk-sagemakerruntime = "1.83.0"
tokio = { version = "1", features = ["full"] }
然后在代码中,可以按以下方式创建客户端:
use aws_sdk_sagemakerruntime as sagemakerruntime;
#[::tokio::main]
async fn main() -> Result<(), sagemakerruntime::Error> {
let config = aws_config::load_from_env().await;
let client = aws_sdk_sagemakerruntime::Client::new(&config);
// ... 使用客户端进行调用
Ok(())
}
完整示例
以下是一个完整的示例,展示如何使用aws-sdk-sagemakerruntime与Amazon SageMaker Runtime服务交互:
use aws_sdk_sagemakerruntime as sagemakerruntime;
use aws_sdk_sagemakerruntime::types::Blob;
use std::error::Error;
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
// 加载AWS配置(从环境变量/配置文件等)
let config = aws_config::load_from_env().await;
// 创建SageMaker Runtime客户端
let client = sagemakerruntime::Client::new(&config);
// 准备输入数据(这里使用示例JSON数据)
let input_data = r#"{"instances": [{"data": [1.0, 2.0, 3.0, 4.0]}]}"#;
let blob = Blob::new(input_data.as_bytes().to_vec());
// 调用SageMaker端点进行预测
let endpoint_name = "your-endpoint-name"; // 替换为您的SageMaker端点名称
let content_type = "application/json"; // 根据您的模型调整内容类型
let resp = client.invoke_endpoint()
.endpoint_name(endpoint_name)
.content_type(content_type)
.body(blob)
.send()
.await?;
// 处理响应
if let Some(output) = resp.body() {
println!("预测结果: {:?}", String::from_utf8_lossy(output.as_ref()));
}
Ok(())
}
使用说明
- 确保您已正确配置AWS凭证(通过环境变量或~/.aws/credentials文件)
- 替换示例中的
your-endpoint-name
为您实际的SageMaker端点名称 - 根据您的模型调整输入数据的格式和内容类型
许可证
该项目基于Apache-2.0许可证。
Rust AWS SDK库aws-sdk-sagemakerruntime的使用指南
介绍
aws-sdk-sagemakerruntime
是Rust语言的AWS SDK的一部分,它允许开发者与Amazon SageMaker Runtime服务进行交互。这个库使Rust应用程序能够轻松调用部署在SageMaker上的机器学习模型端点,发送推理请求并接收预测结果。
Amazon SageMaker Runtime服务提供了低延迟、高吞吐量的API,用于与部署在SageMaker上的模型进行交互。使用这个Rust SDK,你可以将机器学习预测功能集成到你的Rust应用程序中。
安装
在Cargo.toml中添加依赖:
[dependencies]
aws-config = "0.55"
aws-sdk-sagemakerruntime = "0.25"
tokio = { version = "1", features = ["full"] }
基本使用方法
1. 创建客户端
首先需要创建一个SageMaker Runtime客户端:
use aws_sdk_sagemakerruntime::Client;
#[tokio::main]
async fn main() -> Result<(), aws_sdk_sagemakerruntime::Error> {
let config = aws_config::load_from_env().await;
let client = Client::new(&config);
// 后续操作...
Ok(())
}
2. 调用模型端点进行预测
use aws_sdk_sagemakerruntime::types::Blob;
async fn invoke_endpoint(
client: &Client,
endpoint_name: &str,
input_data: Vec<u8>,
) -> Result<(), aws_sdk_sagemakerruntime::Error> {
let response = client
.invoke_endpoint()
.endpoint_name(endpoint_name)
.content_type("application/json") // 根据你的模型调整
.body(Blob::new(input_data))
.send()
.await?;
// 处理响应
let prediction = response.body.collect().await?;
println!("Prediction result: {:?}", prediction.into_bytes());
Ok(())
}
完整示例
下面是一个完整的示例,展示如何调用SageMaker端点进行预测:
use aws_sdk_sagemakerruntime::{Client, types::Blob};
use serde_json::json;
#[tokio::main]
async fn main() -> Result<(), aws_sdk_sagemakerruntime::Error> {
// 加载AWS配置
let config = aws_config::load_from_env().await;
let client = Client::new(&config);
// 准备输入数据
let input_data = json!({
"features": [5.1, 3.5, 1.4, 0.2]
});
let input_bytes = serde_json::to_vec(&input_data).unwrap();
// 调用端点
let endpoint_name = "my-sagemaker-endpoint";
let response = client
.invoke_endpoint()
.endpoint_name(endpoint_name)
.content_type("application/json")
.body(Blob::new(input_bytes))
.send()
.await?;
// 处理响应
let prediction = response.body.collect().await?;
let prediction_bytes = prediction.into_bytes();
// 假设响应是JSON格式
let prediction_result: serde_json::Value = serde_json::from_slice(&prediction_bytes).unwrap();
println!("Prediction result: {}", prediction_result);
Ok(())
}
高级功能
1. 异步流式处理
对于大输入或输出,可以使用流式处理:
use bytes::Bytes;
use futures::stream;
use aws_sdk_sagemakerruntime::primitives::ByteStream;
async fn invoke_endpoint_streaming(
client: &Client,
endpoint_name: &str,
) -> Result<(), aws_sdk_sagemakerruntime::Error> {
// 创建流式输入
let data_stream = stream::iter(vec![
Ok(Bytes::from("part1")),
Ok(Bytes::from("part2")),
]);
let response = client
.invoke_endpoint()
.endpoint_name(endpoint_name)
.content_type("application/octet-stream")
.body(ByteStream::new(data_stream))
.send()
.await?;
// 流式处理响应
let mut response_stream = response.body.into_async_read();
// ... 处理流数据
Ok(())
}
2. 错误处理
async fn safe_invoke(
client: &Client,
endpoint_name: &str,
input: Vec<u8>,
) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
let response = client
.invoke_endpoint()
.endpoint_name(endpoint_name)
.body(Blob::new(input))
.send()
.await
.map_err(|e| {
eprintln!("Failed to invoke endpoint: {}", e);
e
})?;
let prediction = response.body
.collect()
.await
.map_err(|e| {
eprintln!("Failed to read response: {}", e);
e
})?;
Ok(prediction.into_bytes())
}
最佳实践
-
重用客户端:AWS客户端设计为可重用且线程安全,应该尽可能重用而不是每次请求都创建新的
-
设置合理的超时:可以通过AWS配置设置适当的超时时间
-
处理限流:实现重试逻辑处理SageMaker的限流响应
-
内容类型匹配:确保content_type与模型期望的输入格式匹配
-
监控和日志:记录请求和响应大小、延迟等指标
总结
aws-sdk-sagemakerruntime
库为Rust开发者提供了与Amazon SageMaker Runtime服务交互的高效方式。通过这个库,你可以轻松地将部署在SageMaker上的机器学习模型集成到Rust应用程序中,实现低延迟的预测功能。
完整示例Demo
下面是一个增强版的完整示例,包含错误处理、日志记录和配置设置:
use aws_sdk_sagemakerruntime::{Client, types::Blob};
use serde_json::json;
use tracing::info;
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
// 初始化日志
tracing_subscriber::fmt::init();
// 加载AWS配置(可自定义超时等设置)
let config = aws_config::from_env()
.timeout_config(
aws_config::timeout::TimeoutConfig::builder()
.operation_timeout(std::time::Duration::from_secs(30))
.build()
)
.load()
.await;
// 创建客户端
let client = Client::new(&config);
// 准备输入数据
let input_data = json!({
"features": [5.1, 3.5, 1.4, 0.2],
"additional_params": {
"mode": "predict"
}
});
let input_bytes = serde_json::to_vec(&input_data)?;
info!("Prepared input data with size: {} bytes", input_bytes.len());
// 调用端点
let endpoint_name = "my-sagemaker-endpoint";
info!("Invoking endpoint: {}", endpoint_name);
let response = match client
.invoke_endpoint()
.endpoint_name(endpoint_name)
.content_type("application/json")
.body(Blob::new(input_bytes))
.send()
.await
{
Ok(resp) => resp,
Err(e) => {
tracing::error!("Failed to invoke endpoint: {}", e);
return Err(e.into());
}
};
// 处理响应
let prediction = response.body
.collect()
.await
.map_err(|e| {
tracing::error!("Failed to read response: {}", e);
e
})?;
let prediction_bytes = prediction.into_bytes();
info!("Received prediction with size: {} bytes", prediction_bytes.len());
// 解析JSON响应
let prediction_result: serde_json::Value = serde_json::from_slice(&prediction_bytes)?;
println!("Final prediction result: {}", prediction_result);
Ok(())
}
这个增强版示例包含以下改进:
- 添加了日志记录功能,使用tracing库记录关键操作
- 配置了自定义超时设置
- 提供了更详细的错误处理
- 包含了输入/输出数据大小的日志记录
- 使用了更结构化的输入数据格式
要运行此示例,需要在Cargo.toml中添加额外的依赖:
[dependencies]
tracing = "0.1"
tracing-subscriber = "0.3"