Using the Rust AWS machine learning crate aws-sdk-sagemaker: integrating Amazon SageMaker for model training and deployment

Introduction

aws-sdk-sagemaker provides APIs for creating and managing SageMaker resources.

Getting Started

The SDK provides one crate per AWS service. You will also need Tokio as a dependency to run the async code. To add aws-sdk-sagemaker to your project, add the following to your Cargo.toml:

[dependencies]
aws-config = { version = "1.1.7", features = ["behavior-version-latest"] }
aws-sdk-sagemaker = "1.144.0"
tokio = { version = "1", features = ["full"] }

Then, in your code, create a client as follows:

use aws_sdk_sagemaker as sagemaker;

#[tokio::main]
async fn main() -> Result<(), sagemaker::Error> {
    let config = aws_config::load_from_env().await;
    let client = sagemaker::Client::new(&config);

    // ... make API calls with the client

    Ok(())
}

Complete Example

Below is a more complete example showing how to use aws-sdk-sagemaker to create a training job and then deploy a model:

use aws_sdk_sagemaker::Client;
use aws_sdk_sagemaker::types::{
    AlgorithmSpecification,
    Channel,
    ContainerDefinition,
    DataSource,
    OutputDataConfig,
    ProductionVariant,
    ProductionVariantInstanceType,
    ResourceConfig,
    S3DataSource,
    S3DataType,
    StoppingCondition,
    TrainingInputMode,
    TrainingInstanceType,
};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Load the AWS configuration
    let config = aws_config::load_from_env().await;
    let client = Client::new(&config);

    // 1. Create the training job
    let training_job = client.create_training_job()
        .training_job_name("rust-sagemaker-example")
        .algorithm_specification(
            AlgorithmSpecification::builder()
                .training_image("your-training-image") // replace with your actual training image
                .training_input_mode(TrainingInputMode::File)
                .build()?
        )
        .input_data_config(
            Channel::builder()
                .channel_name("train")
                .data_source(
                    DataSource::builder()
                        .s3_data_source(
                            S3DataSource::builder()
                                .s3_data_type(S3DataType::S3Prefix)
                                .s3_uri("s3://your-bucket/training-data/") // replace with your actual data path
                                .build()?
                        )
                        .build()
                )
                .build()?
        )
        .output_data_config(
            OutputDataConfig::builder()
                .s3_output_path("s3://your-bucket/output/") // replace with your actual output path
                .build()?
        )
        .resource_config(
            ResourceConfig::builder()
                .instance_type(TrainingInstanceType::MlM4Xlarge)
                .instance_count(1)
                .volume_size_in_gb(50)
                .build()?
        )
        .stopping_condition(
            StoppingCondition::builder()
                .max_runtime_in_seconds(3600)
                .build()
        )
        .role_arn("arn:aws:iam::your-account-id:role/your-role") // replace with your actual IAM role
        .send()
        .await?;

    println!("Training job created: {:?}", training_job.training_job_arn);

    // 2. Deploy the model
    let model = client.create_model()
        .model_name("rust-sagemaker-model")
        .execution_role_arn("arn:aws:iam::your-account-id:role/your-role") // replace with your actual IAM role
        .primary_container(
            ContainerDefinition::builder()
                .image("your-inference-image") // replace with your actual inference image
                .model_data_url("s3://your-bucket/output/model.tar.gz") // replace with your actual model path
                .build()
        )
        .send()
        .await?;

    println!("Model created: {:?}", model.model_arn);

    // 3. Create the endpoint configuration
    let endpoint_config = client.create_endpoint_config()
        .endpoint_config_name("rust-sagemaker-config")
        .production_variants(
            ProductionVariant::builder()
                .variant_name("all-traffic")
                .model_name("rust-sagemaker-model")
                .initial_instance_count(1)
                .instance_type(ProductionVariantInstanceType::MlM4Xlarge)
                .build()?
        )
        .send()
        .await?;

    println!("Endpoint config created: {:?}", endpoint_config.endpoint_config_arn);

    // 4. Create the endpoint
    let endpoint = client.create_endpoint()
        .endpoint_name("rust-sagemaker-endpoint")
        .endpoint_config_name("rust-sagemaker-config")
        .send()
        .await?;

    println!("Endpoint created: {:?}", endpoint.endpoint_arn);

    Ok(())
}

License

This project is licensed under Apache-2.0.


1 Reply

Introduction

aws-sdk-sagemaker is the official AWS Rust SDK crate for interacting with Amazon SageMaker. Amazon SageMaker is a fully managed machine learning service that lets developers and data scientists quickly build, train, and deploy machine learning models.

This Rust crate provides full access to the SageMaker service, including:

  • Creating and managing training jobs
  • Deploying models to endpoints
  • Managing models and endpoints
  • Handling training data and model artifacts

Usage

1. Add dependencies

First, add the dependencies in Cargo.toml:

[dependencies]
aws-config = "1"
aws-sdk-sagemaker = "1"
tokio = { version = "1", features = ["full"] }

2. Basic setup

use aws_config::BehaviorVersion;
use aws_sdk_sagemaker::{Client, Error};

async fn create_client() -> Result<Client, Error> {
    let config = aws_config::load_defaults(BehaviorVersion::latest()).await;
    let client = Client::new(&config);
    Ok(client)
}

3. Create a training job

use aws_sdk_sagemaker::types::{
    Channel, DataSource, ResourceConfig, S3DataSource, S3DataType,
    TrainingInputMode, TrainingInstanceType,
};

async fn create_training_job(client: &Client, job_name: &str) -> Result<(), Box<dyn std::error::Error>> {
    let s3_data_source = S3DataSource::builder()
        .s3_data_type(S3DataType::S3Prefix)
        .s3_uri("s3://your-bucket/training-data/")
        .build()?;

    let data_source = DataSource::builder()
        .s3_data_source(s3_data_source)
        .build();

    let channel = Channel::builder()
        .channel_name("train")
        .data_source(data_source)
        .input_mode(TrainingInputMode::File)
        .build()?;

    let resource_config = ResourceConfig::builder()
        .instance_count(1)
        .instance_type(TrainingInstanceType::MlM5Large)
        .volume_size_in_gb(30)
        .build()?;

    let output_config = aws_sdk_sagemaker::types::OutputDataConfig::builder()
        .s3_output_path("s3://your-bucket/output/")
        .build()?;

    let response = client.create_training_job()
        .training_job_name(job_name)
        .algorithm_specification(
            aws_sdk_sagemaker::types::AlgorithmSpecification::builder()
                .training_image("your-training-image-uri")
                .training_input_mode(TrainingInputMode::File)
                .build()?
        )
        .role_arn("arn:aws:iam::123456789012:role/service-role/AmazonSageMaker-ExecutionRole")
        .input_data_config(channel)
        .output_data_config(output_config)
        .resource_config(resource_config)
        .stopping_condition(
            aws_sdk_sagemaker::types::StoppingCondition::builder()
                .max_runtime_in_seconds(3600)
                .build()
        )
        .send()
        .await?;

    println!("Training job created: {:?}", response.training_job_arn);
    Ok(())
}
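
The channel and output configurations above are wired together by S3 URIs. When several channels share one bucket, a small stdlib-only helper (purely illustrative, not part of the SDK) can split an s3:// URI into its bucket and key prefix:

```rust
/// Split an "s3://bucket/prefix" URI into (bucket, prefix).
/// Returns None if the scheme or the bucket name is missing.
fn split_s3_uri(uri: &str) -> Option<(&str, &str)> {
    let rest = uri.strip_prefix("s3://")?;
    if rest.is_empty() {
        return None;
    }
    match rest.split_once('/') {
        Some((bucket, prefix)) if !bucket.is_empty() => Some((bucket, prefix)),
        Some(_) => None,            // empty bucket, e.g. "s3:///key"
        None => Some((rest, "")),   // bucket only, no key prefix
    }
}
```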

4. Deploy the model

async fn deploy_model(client: &Client, model_name: &str) -> Result<(), Box<dyn std::error::Error>> {
    let endpoint_config_name = format!("{}-config", model_name);
    let endpoint_name = format!("{}-endpoint", model_name);

    // First create the model
    client.create_model()
        .model_name(model_name)
        .execution_role_arn("arn:aws:iam::123456789012:role/service-role/AmazonSageMaker-ExecutionRole")
        .primary_container(
            aws_sdk_sagemaker::types::ContainerDefinition::builder()
                .image("your-inference-image-uri")
                .model_data_url("s3://your-bucket/model/model.tar.gz")
                .build()
        )
        .send()
        .await?;

    // Create the endpoint configuration
    client.create_endpoint_config()
        .endpoint_config_name(&endpoint_config_name)
        .production_variants(
            aws_sdk_sagemaker::types::ProductionVariant::builder()
                .variant_name("AllTraffic")
                .model_name(model_name)
                .initial_instance_count(1)
                .instance_type(aws_sdk_sagemaker::types::ProductionVariantInstanceType::MlM5Large)
                .build()?
        )
        .send()
        .await?;

    // Create the endpoint
    let endpoint_response = client.create_endpoint()
        .endpoint_name(&endpoint_name)
        .endpoint_config_name(&endpoint_config_name)
        .send()
        .await?;

    println!("Endpoint created: {:?}", endpoint_response.endpoint_arn);
    Ok(())
}
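
deploy_model derives the endpoint and config names with format!. SageMaker resource names (models, endpoints, training jobs) must be 1-63 characters of ASCII letters, digits, and hyphens, starting and ending with a letter or digit. A light pre-flight check like the following (illustrative only; the service performs the authoritative validation) avoids a round trip just to get a ValidationException:

```rust
/// Rough validity check for a SageMaker resource name:
/// 1-63 chars, ASCII alphanumerics and hyphens,
/// must start and end with an alphanumeric character.
fn is_valid_sagemaker_name(name: &str) -> bool {
    !name.is_empty()
        && name.len() <= 63
        && name.chars().next().map_or(false, |c| c.is_ascii_alphanumeric())
        && name.chars().last().map_or(false, |c| c.is_ascii_alphanumeric())
        && name.chars().all(|c| c.is_ascii_alphanumeric() || c == '-')
}
```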

5. Query the training job status

async fn describe_training_job(client: &Client, job_name: &str) -> Result<(), Error> {
    let response = client.describe_training_job()
        .training_job_name(job_name)
        .send()
        .await?;
    
    println!("Training job status: {:?}", response.training_job_status);
    Ok(())
}
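
DescribeTrainingJob returns immediately, but a training job runs for minutes to hours, so callers usually poll until the status is terminal (Completed, Failed, or Stopped). The control flow can be sketched without the SDK; here the status closure stands in for a real describe_training_job call, and backoff_schedule shows one way to space out the polls:

```rust
use std::time::Duration;

/// Capped exponential backoff intervals: base, 2*base, 4*base, ... never above `cap_secs`.
fn backoff_schedule(attempts: u32, base_secs: u64, cap_secs: u64) -> Vec<Duration> {
    (0..attempts)
        .map(|i| Duration::from_secs((base_secs << i).min(cap_secs)))
        .collect()
}

/// Drive a status source until it reports a terminal training status,
/// giving up after `max_attempts`. In real code the closure would wrap
/// `client.describe_training_job(...).send().await`, and the caller would
/// `tokio::time::sleep` for the matching backoff interval between attempts.
fn poll_until_terminal(mut status: impl FnMut() -> String, max_attempts: usize) -> Option<String> {
    for _ in 0..max_attempts {
        let s = status();
        if matches!(s.as_str(), "Completed" | "Failed" | "Stopped") {
            return Some(s);
        }
    }
    None
}
```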

Complete Example

The following is a complete example showing how to create a training job and deploy a model:

use aws_config::BehaviorVersion;
use aws_sdk_sagemaker::{Client, Error};
use aws_sdk_sagemaker::types::{TrainingInputMode, S3DataSource, Channel, ResourceConfig, TrainingInstanceType};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Create the client
    let config = aws_config::load_defaults(BehaviorVersion::latest()).await;
    let client = Client::new(&config);

    // Create the training job
    let job_name = "rust-sagemaker-example";
    create_training_job(&client, job_name).await?;

    // Check the training status
    describe_training_job(&client, job_name).await?;

    // Deploy the model
    let model_name = "rust-sagemaker-model";
    deploy_model(&client, model_name).await?;

    Ok(())
}

The create_training_job, deploy_model, and describe_training_job helpers are the functions defined in sections 3-5 above.

Best Practices

  1. Error handling: handle the Result values returned by the AWS SDK properly; consider the ? operator or a match expression

  2. Async operations: SageMaker operations are typically long-running; be sure to use the tokio runtime and handle the async calls correctly

  3. Resource cleanup: delete models, endpoints, and training jobs you no longer need to avoid unnecessary charges

  4. IAM permissions: make sure the execution role has sufficient permissions to access the S3 buckets and any other required resources

  5. Logging: enable CloudWatch Logs to monitor the training and inference processes
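
On point 4: the role passed as role_arn / execution_role_arn must be assumable by the SageMaker service. A trust policy along these lines is the usual starting point (a sketch; the S3, ECR, and CloudWatch permissions themselves are granted in separate permission policies attached to the role):

```json
{
  "Version": "2012-10-17",
  "Statement": [
    {
      "Effect": "Allow",
      "Principal": { "Service": "sagemaker.amazonaws.com" },
      "Action": "sts:AssumeRole"
    }
  ]
}
```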

Summary

aws-sdk-sagemaker gives Rust developers a powerful way to integrate with Amazon SageMaker. With this crate you can manage the entire machine learning workflow in Rust, from data preparation and model training through to deploying inference endpoints. Rust's type safety and performance make it an excellent choice for building reliable, efficient machine learning pipelines.
