Rust深度学习框架Candle-Transformers的使用:高效NLP与Transformer模型推理库
candle-transformers
元数据 包:cargo/candle-transformers@0.9.1 发布时间:4个月前 版本:2021版 许可证:MIT OR Apache-2.0 大小:361 KiB
安装 在项目目录中运行以下Cargo命令: cargo add candle-transformers
或者在Cargo.toml中添加以下行: candle-transformers = “0.9.1”
文档 docs.rs/candle-transformers/0.9.1
仓库 github.com/huggingface/candle
所有者 Laurent Mazare
类别 Science
以下是一个使用candle-transformers进行文本分类的完整示例:
use candle_core::{Device, Tensor, DType};
use candle_transformers::models::bert::{BertModel, Config};
use tokenizers::Tokenizer;
fn main() -> anyhow::Result<()> {
// 初始化设备(CPU或CUDA)
let device = Device::Cpu;
// 加载BERT模型配置
let config = Config::bert_base_uncased();
let model = BertModel::load(&config, "path/to/model")?;
// 加载分词器
let tokenizer = Tokenizer::from_pretrained("bert-base-uncased", None)?;
// 准备输入文本
let text = "This is a sample text for classification";
// 分词处理
let encoding = tokenizer.encode(text, true)?;
let input_ids = Tensor::new(encoding.get_ids(), &device)?.unsqueeze(0)?;
let attention_mask = Tensor::new(encoding.get_attention_mask(), &device)?.unsqueeze(0)?;
let token_type_ids = Tensor::new(encoding.get_type_ids(), &device)?.unsqueeze(0)?;
// 模型推理
let outputs = model.forward(&input_ids, &attention_mask, &token_type_ids)?;
// 获取池化输出(用于分类任务)
let pooled_output = outputs.pooled_output;
// 添加分类层(这里需要根据具体任务定义分类器)
// let logits = classification_layer.forward(&pooled_output)?;
println!("Pooled output shape: {:?}", pooled_output.shape());
Ok(())
}
以下是一个情感分析任务的完整示例:
use candle_core::{Device, Tensor, DType};
use candle_nn::{Linear, Module, VarBuilder};
use candle_transformers::models::bert::{BertModel, Config};
use tokenizers::Tokenizer;
struct SentimentClassifier {
bert: BertModel,
classifier: Linear,
}
impl SentimentClassifier {
fn new(vb: VarBuilder, config: Config) -> anyhow::Result<Self> {
let bert = BertModel::load(&config, vb.pp("bert"))?;
let classifier = candle_nn::linear(config.hidden_size, 2, vb.pp("classifier"))?;
Ok(Self { bert, classifier })
}
fn forward(&self, input_ids: &Tensor, attention_mask: &Tensor) -> anyhow::Result<Tensor> {
let outputs = self.bert.forward(input_ids, attention_mask, &None)?;
let pooled_output = outputs.pooled_output;
self.classifier.forward(&pooled_output)
}
}
fn main() -> anyhow::Result<()> {
let device = Device::Cpu;
let config = Config::bert_base_uncased();
let vb = VarBuilder::zeros(DType::F32, &device);
let model = SentimentClassifier::new(vb, config)?;
let tokenizer = Tokenizer::from_pretrained("bert-base-uncased", None)?;
// 示例文本
let texts = vec![
"I love this movie!",
"This is terrible.",
];
for text in texts {
let encoding = tokenizer.encode(text, true)?;
let input_ids = Tensor::new(encoding.get_ids(), &device)?.unsqueeze(0)?;
let attention_mask = Tensor::new(encoding.get_attention_mask(), &device)?.unsqueeze(0)?;
let logits = model.forward(&input_ids, &attention_mask)?;
let probs = logits.softmax(1)?;
println!("Text: {}", text);
println!("Sentiment probabilities: {:?}", probs);
}
Ok(())
}
1 回复
Rust深度学习框架Candle-Transformers使用指南
框架简介
Candle-Transformers是基于Rust语言开发的高性能Transformer模型推理库,专注于自然语言处理任务。该框架构建在Candle深度学习框架之上,提供了高效的模型推理能力和简洁的API接口。
核心特性
- 支持多种预训练Transformer模型(BERT、GPT、T5等)
- 零拷贝张量操作,内存效率高
- 支持CPU和GPU加速
- 无运行时开销,编译时优化
- 线程安全的API设计
安装方法
在Cargo.toml中添加依赖:
[dependencies]
candle = "0.3"
candle-transformers = "0.1"
基础使用示例
1. 文本分类
use candle_transformers::models::bert::{BertModel, Config};
use candle_core::{Device, Tensor};
fn main() -> anyhow::Result<()> {
let device = Device::Cpu;
let config = Config::bert_base();
let model = BertModel::load(&config, "path/to/model")?;
let input_ids = Tensor::new(&[[101, 2023, 2003, 1037, 102]], &device)?;
let attention_mask = Tensor::new(&[[1, 1, 1, 1, 1]], &device)?;
let output = model.forward(&input_ids, &attention_mask)?;
println!("Output shape: {:?}", output.shape());
Ok(())
}
2. 文本生成
use candle_transformers::models::gpt2::{GPT2Model, GPT2Config};
use candle_core::{Device, Tensor};
fn generate_text() -> anyhow::Result<()> {
let device = Device::Cpu;
let config = GPT2Config::gpt2_medium();
let model = GPT2Model::load(&config, "path/to/gpt2_model")?;
let input_ids = Tensor::new(&[[50256]], &device)?; // 开始标记
let output = model.generate(&input_ids, 50, 0.9)?;
println!("Generated text: {:?}", output);
Ok(())
}
3. 句子相似度计算
use candle_transformers::models::sentence_transformers::SentenceTransformer;
use candle_core::Device;
fn compute_similarity() -> anyhow::Result<()> {
let device = Device::Cpu;
let model = SentenceTransformer::load("all-MiniLM-L6-v2", &device)?;
let sentences = vec![
"Rust is a systems programming language",
"Candle-Transformers provides efficient NLP inference"
];
let embeddings = model.encode(&sentences)?;
let similarity = embeddings[0].cosine_similarity(&embeddings[1])?;
println!("Similarity score: {:.4}", similarity);
Ok(())
}
高级功能
批量处理
use candle_transformers::pipelines::text_classification::TextClassificationPipeline;
fn batch_processing() -> anyhow::Result<()> {
let pipeline = TextClassificationPipeline::new("bert-base-uncased")?;
let texts = vec![
"I love programming in Rust",
"This framework is amazing",
"Natural language processing is fascinating"
];
let results = pipeline.predict_batch(&texts)?;
for (text, label, score) in results {
println!("Text: {} -> Label: {} (Score: {:.4})", text, label, score);
}
Ok(())
}
自定义模型配置
use candle_transformers::models::bert::Config;
use candle_core::Device;
fn custom_model() -> anyhow::Result<()> {
let mut config = Config::bert_base();
config.hidden_size = 768;
config.num_hidden_layers = 12;
config.num_attention_heads = 12;
let device = Device::Cpu;
let model = BertModel::new(&config, &device)?;
// 使用自定义配置的模型进行推理
Ok(())
}
性能优化技巧
- 使用GPU加速:
let device = Device::cuda_if_available(0)?;
- 批量处理:尽可能使用批量推理减少IO开销
- 模型量化:使用8位或16位精度减少内存使用
- 缓存机制:重复使用已加载的模型实例
注意事项
- 确保模型文件路径正确
- 注意输入张量的形状和数据类型
- 处理可能的内存分配错误
- 考虑使用异步处理提高吞吐量
这个框架特别适合需要高性能NLP推理的Rust应用程序,在保持Rust语言安全性和性能优势的同时,提供了便捷的深度学习模型接口。
完整示例Demo
以下是一个完整的文本分类示例,结合了上述内容中的多个功能:
use candle_transformers::models::bert::{BertModel, Config};
use candle_core::{Device, Tensor, DType};
use anyhow::Result;
/// 完整的文本分类示例
fn main() -> Result<()> {
// 设置设备 - 优先使用GPU,如果没有则使用CPU
let device = Device::cuda_if_available(0).unwrap_or(Device::Cpu);
println!("Using device: {:?}", device);
// 加载BERT基础配置
let config = Config::bert_base();
// 加载预训练模型(需要提前下载模型文件)
let model = BertModel::load(&config, "./models/bert-base-uncased")?;
// 准备输入数据
// 假设输入文本经过tokenizer处理后的ID序列
let input_ids = Tensor::new(&[
[101, 2023, 2003, 1037, 3231, 102], // "I love programming"
[101, 2003, 1037, 2748, 102, 0] // "This is great" (填充0)
], &device)?;
// 注意力掩码
let attention_mask = Tensor::new(&[
[1, 1, 1, 1, 1, 1], // 第一个序列有效长度
[1, 1, 1, 1, 1, 0] // 第二个序列,最后一个位置是填充
], &device)?;
// 执行模型推理
let output = model.forward(&input_ids, &attention_mask)?;
// 输出结果形状
println!("Output shape: {:?}", output.shape());
// 假设这是一个二分类任务,取最后一个隐藏状态进行分类
let logits = output
.narrow(1, output.dim(1)? - 1, 1)? // 取序列最后一个token
.squeeze(1)?; // 去除序列维度
// 应用softmax获取概率分布
let probs = logits.softmax(1)?;
println!("Classification probabilities: {:?}", probs.to_vec2::<f32>());
Ok(())
}
/// 批量处理示例
fn batch_classification_example() -> Result<()> {
let device = Device::Cpu;
// 使用文本分类管道
let pipeline = TextClassificationPipeline::new("bert-base-uncased")?;
// 批量文本数据
let texts = vec![
"I really enjoy using Rust for machine learning",
"The weather is beautiful today",
"This movie is absolutely fantastic",
"The product quality is poor and disappointing"
];
// 批量预测
let results = pipeline.predict_batch(&texts)?;
// 输出结果
for (i, (text, label, score)) in results.iter().enumerate() {
println!("Example {}: '{}' -> {} (confidence: {:.3})",
i + 1, text, label, score);
}
Ok(())
}
/// 自定义配置示例
fn custom_config_example() -> Result<()> {
let device = Device::Cpu;
// 创建自定义配置
let mut config = Config::bert_base();
config.hidden_size = 512; // 减小隐藏层大小
config.num_hidden_layers = 6; // 减少层数
config.num_attention_heads = 8; // 减少注意力头数
config.intermediate_size = 2048; // 中间层大小
// 使用自定义配置创建模型
let model = BertModel::new(&config, &device)?;
println!("Custom model created successfully with config: {:?}", config);
Ok(())
}
这个完整示例展示了:
- 设备选择(自动检测GPU)
- 模型加载和配置
- 输入数据准备(包括填充处理)
- 模型推理和前向传播
- 结果处理和输出
- 批量处理功能
- 自定义模型配置
要运行此示例,需要先下载相应的预训练模型文件到指定路径,并确保所有依赖项正确安装。