Rust高性能BLAS绑定库blas-sys的使用,blas-sys为Rust提供基础线性代数子程序(BLAS)的底层系统接口

Rust高性能BLAS绑定库blas-sys的使用

blas-sys为Rust提供基础线性代数子程序(BLAS)的底层系统接口。

安装

在项目目录中运行以下Cargo命令:

cargo add blas-sys

或者在Cargo.toml中添加以下行:

blas-sys = "0.9.0"

示例使用

下面是一个使用blas-sys进行矩阵向量乘法的完整示例:

use blas_sys::{cblas_dgemv, CblasColMajor, CblasNoTrans};

fn main() {
    // 定义矩阵和向量
    let matrix = vec![1.0, 2.0, 3.0, 4.0]; // 2x2矩阵
    let vector = vec![5.0, 6.0]; // 2维向量
    let mut result = vec![0.0; 2]; // 结果向量
    
    // 矩阵向量乘法参数
    let m = 2; // 行数
    let n = 2; // 列数
    let alpha = 1.0; // 缩放因子
    let lda = 2; // 矩阵A的leading dimension
    let beta = 0.0; // 结果向量缩放因子
    
    unsafe {
        cblas_dgemv(
            CblasColMajor, // 列主序存储
            CblasNoTrans,  // 不转置矩阵
            m, n,          // 矩阵维度
            alpha,         // 缩放因子
            matrix.as_ptr(), // 矩阵数据
            lda,           // leading dimension
            vector.as_ptr(), // 向量数据
            1,             // 向量x的增量
            beta,          // 结果向量缩放因子
            result.as_mut_ptr(), // 结果向量
            1,             // 向量y的增量
        );
    }
    
    println!("Result: {:?}", result); // 应该输出 [17.0, 39.0]
}

完整示例

以下是一个更完整的示例,展示如何使用blas-sys进行矩阵乘法运算:

use blas_sys::{cblas_dgemm, CblasColMajor, CblasNoTrans, CblasTrans};

fn main() {
    // 定义两个矩阵
    let matrix_a = vec![1.0, 2.0, 3.0, 4.0]; // 2x2矩阵
    let matrix_b = vec![5.0, 6.0, 7.0, 8.0]; // 2x2矩阵
    let mut result = vec![0.0; 4]; // 结果矩阵(2x2)
    
    // 矩阵乘法参数
    let m = 2; // 矩阵A的行数
    let n = 2; // 矩阵B的列数
    let k = 2; // 矩阵A的列数/矩阵B的行数
    let alpha = 1.0; // 缩放因子
    let beta = 0.0; // 结果矩阵缩放因子
    
    unsafe {
        cblas_dgemm(
            CblasColMajor, // 列主序存储
            CblasNoTrans,  // 不转置矩阵A
            CblasNoTrans,  // 不转置矩阵B
            m, n, k,       // 矩阵维度
            alpha,         // 缩放因子
            matrix_a.as_ptr(), // 矩阵A数据
            m,            // 矩阵A的leading dimension
            matrix_b.as_ptr(), // 矩阵B数据
            k,            // 矩阵B的leading dimension
            beta,         // 结果矩阵缩放因子
            result.as_mut_ptr(), // 结果矩阵
            m,            // 结果矩阵的leading dimension
        );
    }
    
    println!("Matrix multiplication result: {:?}", result); 
    // 应该输出 [19.0, 22.0, 43.0, 50.0]
}

许可证

blas-sys采用以下许可证之一:

  • Apache-2.0 OR MIT

贡献

我们非常欢迎您的贡献。请不要犹豫,提出问题或提交拉取请求。任何提交的项目贡献都将根据LICENSE.md中给出的条款进行许可。


1 回复

Rust高性能BLAS绑定库blas-sys使用指南

简介

blas-sys是Rust语言的一个基础线性代数子程序(BLAS)的底层系统接口绑定库。它提供了对BLAS库的直接FFI绑定,允许Rust程序调用高性能的线性代数运算函数。

BLAS(Basic Linear Algebra Subprograms)是线性代数运算的标准接口,被广泛用于科学计算和高性能计算领域。blas-sys作为底层绑定,通常被其他高级线性代数库(如ndarray等)所使用。

安装

Cargo.toml中添加依赖:

[dependencies]
blas-sys = "0.7"

注意:blas-sys本身不包含BLAS实现,你需要确保系统已安装以下任一BLAS实现:

  • OpenBLAS
  • Intel MKL
  • Apple Accelerate Framework (macOS)
  • Netlib BLAS

基本使用方法

1. 向量点积 (sdot)

extern crate blas_sys;

fn main() {
    unsafe {
        let n = 3;
        let x = vec![1.0, 2.0, 3.0];
        let incx = 1;
        let y = vec![4.0, 5.极,5.0, 6.0];
        let incy = 1;
        
        let result = blas_sys::sdot(
            &n, 
            x.as_ptr(), 
            &incx, 
            y.as_ptr(), 
            &incy
        );
        
        println!("Dot product: {}", result); // 输出: 32.0 (1*4 + 2*5 + 3*6)
    }
}

2. 矩阵-向量乘法 (sgemv)

extern crate blas_sys;

fn main() {
    unsafe {
        let m = 2; // 行数
        let n = 3; // 列数
        let alpha = 1.0;
        let a = vec![1.0, 2.0,  // 列主序矩阵
                     3.0, 4.0,
                     5.0, 6.0];
        let lda = 2; // 矩阵的前导维度
        let x = vec![1.0, 2.0, 3.0];
        let incx = 1;
        let beta = 0.0;
        let mut y = vec![0.0, 0.0];
        let incy = 1;
        
        blas_sys::sgemv(
            b'N',  // 不转置
            &m,
            &n,
            &alpha,
            a.as_ptr(),
            &lda,
            x.as_ptr(),
            &incx,
            &beta,
            y.as_mut_ptr(),
            &incy
        );
        
        println!("Result: {:?}", y); // 输出: [22.0, 28.0]
    }
}

3. 矩阵-矩阵乘法 (sgemm)

extern crate blas_sys;

fn main() {
    unsafe {
        let m = 2; // A的行数,C的行数
        let n = 2; // B的列数,C的列数
        let k极 3; // A的列数,B的行数
        let alpha = 1.0;
        let a = vec![1.0, 2.0,  // 列主序矩阵A
                    3.0, 4.0,
                    5.0, 6.0];
        let lda = 2; // A的前导维度
        let b = vec![1.0, 2.0, 3.0,  // 列主序矩阵B
                     4.0, 5.0, 6.0];
        let ldb = 3; // B的前导维度
        let beta = 0.0;
        let mut c = vec![0.0, 0.0, 0.0, 0.0]; // 结果矩阵C
        let ldc = 2; // C的前导维度
        
        blas_sys::sgemm(
            b'N',  // A不转置
            b'N',  // B不转置
            &m,
            &n,
            &k,
            &alpha,
            a.as_ptr(),
            &lda,
            b.as_ptr(),
            &ldb,
            &beta,
            c.as_mut_ptr(),
            &ldc
        );
        
        println!("Result matrix: {:?}", c); // 输出: [22.0, 28.0, 49.0, 64.0]
    }
}

常用函数分类

blas-sys提供了BLAS的三个级别函数:

1. Level 1 (向量运算)

  • sdot/ddot: 向量点积
  • saxpy/daxpy: 向量加法
  • scopy/dcopy: 向量复制
  • snrm2/dnrm2: 向量欧几里得范数

2. Level 2 (矩阵-向量运算)

  • sgemv/dgemv: 通用矩阵-向量乘法
  • ssymv/dsymv: 对称矩阵-向量乘法
  • strmv/dtrmv: 三角矩阵-向量乘法

3. Level 3 (矩阵-矩阵运算)

  • sgemm/dgemm: 通用矩阵乘法
  • ssymm/dsymm: 对称矩阵乘法
  • strmm/dtrmm: 三角矩阵乘法

注意事项

  1. 所有函数调用都需要在unsafe块中进行,因为它们直接调用外部C函数
  2. 矩阵使用列主序(column-major)存储
  3. 函数命名前缀表示数据类型:
    • s: 单精度浮点数(f32)
    • d: 双精度浮点数(f64)
    • c: 单精度复数
    • z: 双精度复数
  4. 实际使用时,建议考虑更高级的封装库(如ndarray + ndarray-linalg),除非你需要直接控制BLAS调用

性能建议

  1. 确保链接到优化的BLAS实现(如OpenBLAS或Intel MKL)
  2. 对于小型矩阵/向量,BLAS调用的开销可能超过计算本身
  3. 尽可能重用内存而不是频繁分配/释放
  4. 使用适当的转置参数避免不必要的内存访问模式

blas-sys为Rust提供了直接访问高性能线性代数运算的能力,是构建科学计算应用的重要基础组件。

回到顶部