Rust PostgreSQL协议库pgwire的使用,实现高性能PostgreSQL客户端与服务器通信

Rust PostgreSQL协议库pgwire的使用,实现高性能PostgreSQL客户端与服务器通信

pgwire是一个Rust库,实现了PostgreSQL Wire Protocol,提供了编写PostgreSQL兼容服务器和客户端的基本API。它类似于hyper,但是专门针对PostgreSQL协议。

项目状态

pgwire已经实现了以下功能:

  • 前端-后端协议消息(3.0和3.2版本)
  • 基于Tokio的后端TCP/TLS服务器
  • 多种认证方式(无认证、明文密码、MD5密码、SASL SCRAM认证)
  • 简单查询和扩展查询协议
  • 多种数据类型支持(文本格式和二进制格式)

PostgreSQL Wire协议概述

PostgreSQL Wire Protocol是一个相对通用的第7层协议,包含6个部分:

  1. 启动:客户端-服务器握手和认证
  2. 简单查询:基于文本的查询协议
  3. 扩展查询:支持服务器端查询缓存和参数重用的子协议
  4. 复制:用于从PostgreSQL复制数据的子协议
  5. 逻辑复制
  6. 数据复制

完整示例代码

服务器端完整示例

use pgwire::api::{
    auth::noop::NoopStartupHandler, 
    query::PlaceholderExtendedQueryHandler,
    AuthSource, MakeHandler, ServerParameterProvider
};
use pgwire::error::PgWireResult;
use pgwire::server::{Server, ServerConfiguration};
use std::collections::HashMap;
use std::sync::Arc;

// 服务器参数提供者
struct DummyParameterProvider;
impl ServerParameterProvider for DummyParameterProvider {
    fn server_parameters<C>(&self, _: &C) -> Option<HashMap<String, String>> {
        let mut params = HashMap::new();
        params.insert("server_version".to_owned(), "14.0".to_owned());
        params.insert("client_encoding".to_owned(), "UTF8".to_owned());
        Some(params)
    }
}

// 认证处理器
struct DummyAuth;
impl AuthSource for DummyAuth {
    fn get_password(&self, username: &str) -> PgWireResult<String> {
        Ok(format!("password_for_{}", username))
    }
}

#[tokio::main]
async fn main() -> PgWireResult<()> {
    // 创建查询处理器
    let processor = Arc::new(PlaceholderExtendedQueryHandler::new());
    let processor = MakeHandler::new(processor);
    
    // 创建启动处理器
    let startup_handler = Arc::new(NoopStartupHandler::new());
    
    // 配置服务器
    let config = ServerConfiguration {
        parameter_provider: Arc::new(DummyParameterProvider),
        auth_source: Arc::new(DummyAuth),
        startup_handler,
        query_handler: processor,
        ..Default::default()
    };

    // 创建并启动服务器
    let server = Server::new(config);
    println!("PostgreSQL 服务器运行在 127.0.0.1:5432");
    server.run_on_tcp("127.0.0.1:5432").await?;
    
    Ok(())
}

客户端完整示例

use pgwire::client::{Client, ClientConfig};
use pgwire::error::PgWireResult;
use pgwire::messages::response::FieldInfo;
use std::time::Instant;

#[tokio::main]
async fn main() -> PgWireResult<()> {
    // 配置客户端
    let config = ClientConfig {
        username: "test_user".to_owned(),
        password: Some("password_for_test_user".to_owned()),
        database: Some("test_db".to_owned()),
        ..Default::default()
    };

    println!("正在连接到 PostgreSQL 服务器...");
    let start_time = Instant::now();
    
    // 创建并连接客户端
    let mut client = Client::connect("127.0.0.1:5432", config).await?;
    
    println!("连接成功! 耗时: {:?}", start_time.elapsed());
    
    // 执行简单查询
    println!("执行简单查询: SELECT version()");
    let result = client.simple_query("SELECT version()").await?;
    
    // 处理查询结果
    if let Some(row) = result.rows.get(0) {
        println!("服务器版本: {}", row.get::<String>(0).unwrap());
    }
    
    // 执行带参数的扩展查询
    println!("执行扩展查询: SELECT $1::TEXT");
    let extended = client.prepare("SELECT $1::TEXT", &[]).await?;
    let result = client.execute(&extended, &[&"Hello pgwire"]).await?;
    
    if let Some(row) = result.rows.get(0) {
        println!("查询结果: {}", row.get::<String>(0).unwrap());
    }
    
    println!("关闭连接...");
    client.close().await?;
    
    Ok(())
}

项目使用案例

pgwire已经被多个项目使用,包括:

  • GreptimeDB:云原生时序数据库
  • risinglight:用于教育目的的OLAP数据库系统
  • PeerDB:PostgreSQL优先的ETL/ELT工具
  • CeresDB:来自AntGroup的高性能分布式时序数据库
  • dozer:实时数据平台
  • restate:构建弹性工作流的框架

许可证

pgwire采用MIT/Apache双许可证发布。


1 回复

Rust PostgreSQL协议库pgwire的使用指南

pgwire是一个用于实现PostgreSQL协议通信的Rust库,它允许开发者构建高性能的PostgreSQL客户端和服务器组件。

主要特性

  • 实现了PostgreSQL的前后端协议(3.0版本)
  • 支持异步I/O(tokio兼容)
  • 纯Rust实现,无外部依赖
  • 提供客户端和服务器端构建块

基本使用方法

添加依赖

首先在Cargo.toml中添加依赖:

[dependencies]
pgwire = "0.6"
tokio = { version = "1.0", features = ["full"] }

构建简单PostgreSQL服务器

use pgwire::api::{auth::noop::NoopStartupHandler, MakeHandler, StatelessMakeHandler};
use pgwire::error::PgWireResult;
use pgwire::tokio::process_socket;

#[tokio::main]
pub async fn main() -> PgWireResult<()> {
    let processor = StatelessMakeHandler::new(Arc::new(NoopBackend::new()));
    let authenticator = Arc::new(NoopStartupHandler);
    
    let server_addr = "127.0.0.1:5432";
    let listener = TcpListener::bind(server_addr).await?;
    println!("Listening to {}", server_addr);
    
    loop {
        let incoming_socket = listener.accept().await?.0;
        let authenticator_ref = authenticator.clone();
        let processor_ref = processor.make();
        
        tokio::spawn(async move {
            process_socket(
                incoming_socket,
                None,
                authenticator_ref,
                processor_ref,
            )
            .await
        });
    }
}

实现查询处理器

use pgwire::api::{ClientInfo, QueryHandler, Response, StatementOrPortal};
use pgwire::error::PgWireResult;
use pgwire::messages::response::ErrorResponse;
use pgwire::messages::PgWireBackendMessage;
use std::sync::Arc;

struct MyQueryHandler;

#[async_trait::async_trait]
impl QueryHandler for MyQueryHandler {
    async fn do_query<'a>(
        &self,
        _client: &mut ClientInfo,
        query: &'a str,
        _stmt: StatementOrPortal<'a>,
    ) -> PgWireResult<Response> {
        if query == "SELECT 1" {
            Ok(Response::Select {
                schema: None,
                results: vec![vec![Some("1".to_owned())]],
            })
        } else {
            Ok(Response::Error(ErrorResponse::error(
                "ERROR".to_owned(),
                "42P01".to_owned(),
                "Unsupported query".to_owned(),
            )))
        }
    }
}

构建PostgreSQL客户端

use pgwire::tokio_postgres::{Client, NoTls};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // 连接到服务器
    let (client, connection) = pgwire::tokio_postgres::connect(
        "host=localhost user=postgres dbname=postgres",
        NoTls,
    )
    .await?;

    // 处理连接任务
    tokio::spawn(async move {
        if let Err(e) = connection.await {
            eprintln!("connection error: {}", e);
        }
    });

    // 执行查询
    let rows = client.query("SELECT 1", &[]).await?;
    for row in rows {
        let value: i32 = row.get(0);
        println!("value: {}", value);
    }

    Ok(())
}

高级功能

自定义认证

use pgwire::api::auth::{AuthSource, ServerParameterProvider, StartupHandler};
use pgwire::error::{PgWireError, PgWireResult};

struct MyAuthHandler;

#[async_trait::async_trait]
impl StartupHandler for MyAuthHandler {
    async fn on startup<C>(
        &self,
        client: &mut C,
        _: AuthSource,
    ) -> PgWireResult<Box<dyn ServerParameterProvider>>
    where
        C: ClientInfo + Unpin + Send + Sync,
    {
        // 实现自定义认证逻辑
        if client.user().as_deref() != Some("admin") {
            Err(PgWireError::UserError(
                "Only admin user is allowed".to_owned(),
            ))
        } else {
            Ok(Box::new(HashMap::new()))
        }
    }
}

处理参数化查询

use pgwire::api::{ClientInfo, QueryHandler, Response, StatementOrPortal};

#[async_trait::async_trait]
impl QueryHandler for MyQueryHandler {
    async fn do_query<'a>(
        &self,
        _client: &mut ClientInfo,
        query: &'a str,
        stmt: StatementOrPortal<'a>,
    ) -> PgWireResult<Response> {
        match stmt {
            StatementOrPortal::Statement(stmt) => {
                println!("Preparing statement: {}", stmt.statement());
                Ok(Response::BindComplete)
            }
            StatementOrPortal::Portal(portal) => {
                println!("Executing portal with params: {:?}", portal.parameters());
                Ok(Response::ExecuteComplete(1))
            }
        }
    }
}

性能优化建议

  1. 使用连接池管理客户端连接
  2. 对于高吞吐量场景,考虑使用Arc共享查询处理器
  3. 实现批量查询处理以减少上下文切换
  4. 使用tokio-postgres的异步流处理大结果集

pgwire提供了构建PostgreSQL兼容服务器和客户端所需的基础组件,开发者可以根据需要扩展功能或优化性能。

完整示例代码

下面是一个完整的PostgreSQL服务器实现示例,包含自定义认证和查询处理:

use std::collections::HashMap;
use std::sync::Arc;
use tokio::net::TcpListener;
use pgwire::api::auth::{AuthSource, ServerParameterProvider, StartupHandler};
use pgwire::api::{ClientInfo, MakeHandler, QueryHandler, Response, StatementOrPortal};
use pgwire::error::{PgWireError, PgWireResult};
use pgwire::messages::response::ErrorResponse;
use pgwire::tokio::process_socket;

// 自定义认证处理器
struct MyAuthHandler;

#[async_trait::async_trait]
impl StartupHandler for MyAuthHandler {
    async fn on_startup<C>(
        &self,
        client: &mut C,
        _: AuthSource,
    ) -> PgWireResult<Box<dyn ServerParameterProvider>>
    where
        C: ClientInfo + Unpin + Send + Sync,
    {
        // 只允许admin用户连接
        if client.user().as_deref() != Some("admin") {
            Err(PgWireError::UserError(
                "Authentication failed: only admin user is allowed".to_owned(),
            ))
        } else {
            Ok(Box::new(HashMap::new()))
        }
    }
}

// 自定义查询处理器
struct MyQueryHandler;

#[async_trait::async_trait]
impl QueryHandler for MyQueryHandler {
    async fn do_query<'a>(
        &self,
        _client: &mut ClientInfo,
        query: &'a str,
        stmt: StatementOrPortal<'a>,
    ) -> PgWireResult<Response> {
        match stmt {
            StatementOrPortal::Statement(_) => {
                // 处理简单查询
                if query == "SELECT 1" {
                    Ok(Response::Select {
                        schema: None,
                        results: vec![vec![Some("1".to_owned())]],
                    })
                } else if query.starts_with("SELECT") {
                    // 模拟返回多行数据
                    Ok(Response::Select {
                        schema: None,
                        results: vec![
                            vec![Some("1".to_owned()), Some("Alice".to_owned())],
                            vec![Some("2".to_owned()), Some("Bob".to_owned())],
                        ],
                    })
                } else {
                    Ok(Response::Error(ErrorResponse::error(
                        "ERROR".to_owned(),
                        "42P01".to_owned(),
                        format!("Unsupported query: {}", query),
                    )))
                }
            }
            StatementOrPortal::Portal(portal) => {
                // 处理参数化查询
                println!("Executing portal with params: {:?}", portal.parameters());
                Ok(Response::ExecuteComplete(1))
            }
        }
    }
}

#[tokio::main]
async fn main() -> PgWireResult<()> {
    // 创建查询处理器
    let processor = StatelessMakeHandler::new(Arc::new(MyQueryHandler));
    // 创建认证处理器
    let authenticator = Arc::new(MyAuthHandler);
    
    // 绑定到本地5432端口
    let server_addr = "127.0.0.1:5432";
    let listener = TcpListener::bind(server_addr).await?;
    println!("PostgreSQL server listening on {}", server_addr);
    
    // 接受连接循环
    loop {
        let incoming_socket = listener.accept().await?.0;
        let authenticator_ref = authenticator.clone();
        let processor_ref = processor.make();
        
        // 为每个连接生成新任务
        tokio::spawn(async move {
            if let Err(e) = process_socket(
                incoming_socket,
                None,
                authenticator_ref,
                processor_ref,
            ).await {
                eprintln!("Connection error: {}", e);
            }
        });
    }
}

要测试这个服务器,可以使用任何PostgreSQL客户端工具连接,用户名为"admin",不需要密码。支持的查询包括简单的"SELECT 1"和其他以"SELECT"开头的查询。

对于更复杂的实现,您可以扩展MyQueryHandler来处理更多的SQL命令,或者实现更精细的权限控制和结果集处理。

回到顶部