Golang中的事务处理详解

Golang中的事务处理详解 我正在开发一个健身房桌面应用程序。随着开发的进行,应用程序在不断增长,我没有过多关注面向对象编程之类的东西,而是尽可能地保持简单并使其正常工作。现在我已经达到了我知识的极限,我已经完成了注册用户的交易处理。由于用户将注册客户,所以没有身份验证之类的功能。因此,我不知道我是否正确处理了我的交易,或者我可以改进什么。简而言之,如果您能给我一些反馈,那将对我有很大帮助。

type userRepository struct {
    storage *postgres.PgxStorage
}

func NewUserRepository(storage *postgres.PgxStorage) user.Repository {
    return &userRepository{storage: storage}
}

func (ur *userRepository) RegisterUser(ctx context.Context, register *entities.RegisterUsertx) (int32, error) {

    tx, err := ur.storage.DBPool.Begin(ctx)
    if err != nil {
        log.Printf("error beginning transaction: %v", err)
        return 0, fmt.Errorf("begin transaction: %w", err)
    }
    defer func() {
        if err != nil {
            tx.Rollback(ctx)
            log.Printf("transaction rolled back due to error: %v", err)
        }
    }()

    // Registrar usuario
    var userID int32
    query := "INSERT INTO users (name, lastname1, lastname2, email, phone, created_at) VALUES($1, $2, $3, $4, $5, $6) RETURNING id"
    err = tx.QueryRow(ctx, query, register.Name, register.Lastname1, register.Lastname2, register.Email, register.Phone, register.CreatedAt).Scan(&userID)
    if err != nil {
        log.Printf("error inserting user: %v", err)
        tx.Rollback(ctx)
        return 0, fmt.Errorf("insert user: %w", err)
    }
    register.ID = userID

    // Crear cuenta
    accountID := uuid.New()
    var account uuid.UUID
    query = "INSERT INTO accounts (user_id, account_id, account_type_id, created_at) VALUES($1, $2, $3, $4) RETURNING account_id"
    err = tx.QueryRow(ctx, query, userID, accountID, register.AccountTypeID, time.Now()).Scan(&account)
    if err != nil {
        log.Printf("error inserting account: %v", err)
        tx.Rollback(ctx)
        return 0, fmt.Errorf("insert account: %w", err)
    }

    // Obtener la duracion mediante el id de la suscripcion
    var subscriptionDuration int
    query = "SELECT subscription_day FROM subscription_costs WHERE id = $1"
    err = tx.QueryRow(ctx, query, register.SubscriptionCostID).Scan(&subscriptionDuration)
    if err != nil {
        log.Printf("error getting subscription duration: %v", err)
        tx.Rollback(ctx)
        return 0, fmt.Errorf("get subscription duration: %w", err)
    }

    startDate := time.Now()
    endDate := startDate.AddDate(0, 0, subscriptionDuration)

    // Crear subscripcion
    query = "INSERT INTO subscriptions (account_id, subscription_cost_id, start_date, end_date) VALUES($1, $2, $3, $4) RETURNING id"
    err = tx.QueryRow(ctx, query, account, register.SubscriptionCostID, startDate, endDate).Scan(&register.SubscriptionID)
    if err != nil {
        log.Printf("error inserting subscription: %v", err)
        tx.Rollback(ctx)
        return 0, fmt.Errorf("insert subscription: %w", err)
    }

    // Obtener el monto y comparar con el monto de la suscripcion
    var expectedCost float64
    query = "SELECT cost from subscription_costs where id = $1"
    err = tx.QueryRow(ctx, query, register.SubscriptionCostID).Scan(&expectedCost)
    if err != nil {
        log.Printf("error getting subscription cost: %v", err)
        tx.Rollback(ctx)
        return 0, fmt.Errorf("ammount: %w", err)
    }

    if register.Ammount != expectedCost {
        log.Printf("amount incorrect: expected %v, got %v", expectedCost, register.Ammount)
        tx.Rollback(ctx)
        return 0, fmt.Errorf("amount incorrect: %w", err)
    }

    query = "INSERT INTO payments (account_id, payment_type_id, cost, payment_date) VALUES($1, $2, $3, $4)"
    _, err = tx.Exec(ctx, query, account, register.PaymentTypeID, register.Ammount, time.Now())
    if err != nil {
        log.Printf("error inserting payment: %v", err)
        tx.Rollback(ctx)
        return 0, fmt.Errorf("insert payment: %w", err)
    }

    // Obtener el id del status
    var statusPayment, statusAcccount int32

    query = `SELECT id FROM STATUS WHERE id = 5 OR id = 1`
    rows, err := tx.Query(ctx, query)
    if err != nil {
        log.Printf("error getting statuses: %v", err)
        tx.Rollback(ctx)
        return 0, fmt.Errorf("get statuses: %w", err)
    }
    defer rows.Close()

    for rows.Next() {
        var id int32
        if err := rows.Scan(&id); err != nil {
            log.Printf("error scanning status: %v", err)
            tx.Rollback(ctx)
            return 0, fmt.Errorf("scan status: %w", err)
        }

        if id == 5 {
            statusPayment = id
        } else {
            statusAcccount = id
        }
    }
    if err := rows.Err(); err != nil {
        log.Printf("rows error: %v", err)
        tx.Rollback(ctx)
        return 0, fmt.Errorf("rows error: %w", err)
    }

    query = `SELECT id from status where id = 5`
    err = tx.QueryRow(ctx, query).Scan(&statusPayment)
    if err != nil {
        log.Printf("error getting status: %v", err)
        tx.Rollback(ctx)
        return 0, fmt.Errorf("get status: %w", err)
    }

    query = "UPDATE payments SET status_id = $1 WHERE account_id = $2"
    _, err = tx.Exec(ctx, query, statusPayment, &account)
    if err != nil {
        log.Printf("error updating payment status: %v", err)
        tx.Rollback(ctx)
        return 0, fmt.Errorf("insert status: %w", err)
    }

    query = "UPDATE accounts SET subscription_id = $1, status_id = $2 WHERE account_id = $3"
    _, err = tx.Exec(ctx, query, &register.SubscriptionID, statusAcccount, &account)
    if err != nil {
        log.Printf("error updating account status: %v", err)
        tx.Rollback(ctx)
        return 0, fmt.Errorf("get status: %w", err)
    }

    err = tx.Commit(ctx)
    if err != nil {
        log.Printf("error committing transaction: %v", err)
        return 0, fmt.Errorf("commit transaction: %w", err)
    }

    return userID, nil
}

更多关于Golang中的事务处理详解的实战教程也可以访问 https://www.itying.com/category-94-b0.html

2 回复

首先,你需要知道何时应该使用事务,例如:

  1. 批量操作:当执行一批插入、更新或删除操作,并且这些操作只有在全部成功时才应成功时。
  2. 复杂操作:当单个逻辑操作涉及多个数据库操作时。
  3. 一致性保证:当你需要确保跨多个表或操作的数据一致性时。

在你的示例中,有些只是查询,你不需要使用事务。 然后你应该知道如何使用它。 只需开始事务。 如果某些操作失败, 你就调用回滚。 从你的代码来看,你已经知道如何使用它了。

更多关于Golang中的事务处理详解的实战系列教程也可以访问 https://www.itying.com/category-94-b0.html


你的代码在事务处理方面有几个可以改进的地方。以下是具体的优化建议和示例代码:

1. 事务错误处理优化

当前代码中多次调用 tx.Rollback(ctx),但 defer 中的回滚逻辑可能不会按预期工作:

func (ur *userRepository) RegisterUser(ctx context.Context, register *entities.RegisterUsertx) (int32, error) {
    tx, err := ur.storage.DBPool.Begin(ctx)
    if err != nil {
        log.Printf("error beginning transaction: %v", err)
        return 0, fmt.Errorf("begin transaction: %w", err)
    }
    
    // 使用局部变量跟踪错误,确保defer正确执行
    var txErr error
    defer func() {
        if txErr != nil {
            if rbErr := tx.Rollback(ctx); rbErr != nil {
                log.Printf("rollback error: %v", rbErr)
            }
        }
    }()
    
    // 后续操作使用 txErr 而不是 err
    var userID int32
    query := "INSERT INTO users (name, lastname1, lastname2, email, phone, created_at) VALUES($1, $2, $3, $4, $5, $6) RETURNING id"
    txErr = tx.QueryRow(ctx, query, register.Name, register.Lastname1, register.Lastname2, register.Email, register.Phone, register.CreatedAt).Scan(&userID)
    if txErr != nil {
        log.Printf("error inserting user: %v", txErr)
        return 0, fmt.Errorf("insert user: %w", txErr)
    }
    
    // 移除所有手动调用 tx.Rollback(ctx) 的代码
    // 只返回错误,让defer处理回滚
    
    // 提交事务
    if commitErr := tx.Commit(ctx); commitErr != nil {
        log.Printf("error committing transaction: %v", commitErr)
        return 0, fmt.Errorf("commit transaction: %w", commitErr)
    }
    
    return userID, nil
}

2. 使用事务上下文

确保使用事务上下文进行所有数据库操作:

func (ur *userRepository) RegisterUser(ctx context.Context, register *entities.RegisterUsertx) (int32, error) {
    // 开始事务时传递上下文
    tx, err := ur.storage.DBPool.BeginTx(ctx, pgx.TxOptions{})
    if err != nil {
        return 0, fmt.Errorf("begin transaction: %w", err)
    }
    
    defer func() {
        if err != nil {
            tx.Rollback(ctx)
        }
    }()
    
    // 所有数据库操作都使用事务的上下文
    var userID int32
    err = tx.QueryRow(ctx, 
        "INSERT INTO users (...) VALUES(...) RETURNING id",
        register.Name, register.Lastname1, register.Lastname2, 
        register.Email, register.Phone, register.CreatedAt,
    ).Scan(&userID)
    
    if err != nil {
        return 0, fmt.Errorf("insert user: %w", err)
    }
    
    // 继续其他操作...
    
    if err := tx.Commit(ctx); err != nil {
        return 0, fmt.Errorf("commit transaction: %w", err)
    }
    
    return userID, nil
}

3. 查询优化

你的代码中有重复查询和可以优化的地方:

// 优化前:两次查询 subscription_costs 表
query = "SELECT subscription_day FROM subscription_costs WHERE id = $1"
err = tx.QueryRow(ctx, query, register.SubscriptionCostID).Scan(&subscriptionDuration)

// ... 后面又查询
query = "SELECT cost from subscription_costs where id = $1"
err = tx.QueryRow(ctx, query, register.SubscriptionCostID).Scan(&expectedCost)

// 优化后:一次查询获取所有需要的数据
var subscriptionDuration int
var expectedCost float64
query = "SELECT subscription_day, cost FROM subscription_costs WHERE id = $1"
err = tx.QueryRow(ctx, query, register.SubscriptionCostID).Scan(&subscriptionDuration, &expectedCost)
if err != nil {
    log.Printf("error getting subscription details: %v", err)
    return 0, fmt.Errorf("get subscription details: %w", err)
}

4. 状态查询优化

你的状态查询逻辑可以简化:

// 优化前:复杂的查询和循环
query = `SELECT id FROM STATUS WHERE id = 5 OR id = 1`
rows, err := tx.Query(ctx, query)
// ... 循环处理

// 优化后:直接查询需要的状态
var statusPayment, statusAccount int32

// 查询支付状态
err = tx.QueryRow(ctx, "SELECT id FROM status WHERE id = 5").Scan(&statusPayment)
if err != nil {
    return 0, fmt.Errorf("get payment status: %w", err)
}

// 查询账户状态
err = tx.QueryRow(ctx, "SELECT id FROM status WHERE id = 1").Scan(&statusAccount)
if err != nil {
    return 0, fmt.Errorf("get account status: %w", err)
}

5. 完整优化示例

func (ur *userRepository) RegisterUser(ctx context.Context, register *entities.RegisterUsertx) (int32, error) {
    tx, err := ur.storage.DBPool.BeginTx(ctx, pgx.TxOptions{})
    if err != nil {
        return 0, fmt.Errorf("begin transaction: %w", err)
    }
    
    defer func() {
        if err != nil {
            tx.Rollback(ctx)
        }
    }()
    
    // 1. 插入用户
    var userID int32
    err = tx.QueryRow(ctx,
        "INSERT INTO users (name, lastname1, lastname2, email, phone, created_at) VALUES($1, $2, $3, $4, $5, $6) RETURNING id",
        register.Name, register.Lastname1, register.Lastname2, register.Email, register.Phone, register.CreatedAt,
    ).Scan(&userID)
    if err != nil {
        return 0, fmt.Errorf("insert user: %w", err)
    }
    register.ID = userID
    
    // 2. 创建账户
    accountID := uuid.New()
    var account uuid.UUID
    err = tx.QueryRow(ctx,
        "INSERT INTO accounts (user_id, account_id, account_type_id, created_at) VALUES($1, $2, $3, $4) RETURNING account_id",
        userID, accountID, register.AccountTypeID, time.Now(),
    ).Scan(&account)
    if err != nil {
        return 0, fmt.Errorf("insert account: %w", err)
    }
    
    // 3. 获取订阅详情(一次查询)
    var subscriptionDuration int
    var expectedCost float64
    err = tx.QueryRow(ctx,
        "SELECT subscription_day, cost FROM subscription_costs WHERE id = $1",
        register.SubscriptionCostID,
    ).Scan(&subscriptionDuration, &expectedCost)
    if err != nil {
        return 0, fmt.Errorf("get subscription details: %w", err)
    }
    
    // 4. 验证金额
    if register.Ammount != expectedCost {
        return 0, fmt.Errorf("amount incorrect: expected %v, got %v", expectedCost, register.Ammount)
    }
    
    // 5. 创建订阅
    startDate := time.Now()
    endDate := startDate.AddDate(0, 0, subscriptionDuration)
    err = tx.QueryRow(ctx,
        "INSERT INTO subscriptions (account_id, subscription_cost_id, start_date, end_date) VALUES($1, $2, $3, $4) RETURNING id",
        account, register.SubscriptionCostID, startDate, endDate,
    ).Scan(&register.SubscriptionID)
    if err != nil {
        return 0, fmt.Errorf("insert subscription: %w", err)
    }
    
    // 6. 插入支付记录
    _, err = tx.Exec(ctx,
        "INSERT INTO payments (account_id, payment_type_id, cost, payment_date) VALUES($1, $2, $3, $4)",
        account, register.PaymentTypeID, register.Ammount, time.Now(),
    )
    if err != nil {
        return 0, fmt.Errorf("insert payment: %w", err)
    }
    
    // 7. 更新支付状态
    _, err = tx.Exec(ctx,
        "UPDATE payments SET status_id = 5 WHERE account_id = $1",
        account,
    )
    if err != nil {
        return 0, fmt.Errorf("update payment status: %w", err)
    }
    
    // 8. 更新账户状态和订阅ID
    _, err = tx.Exec(ctx,
        "UPDATE accounts SET subscription_id = $1, status_id = 1 WHERE account_id = $2",
        register.SubscriptionID, account,
    )
    if err != nil {
        return 0, fmt.Errorf("update account: %w", err)
    }
    
    // 提交事务
    if err = tx.Commit(ctx); err != nil {
        return 0, fmt.Errorf("commit transaction: %w", err)
    }
    
    return userID, nil
}

主要改进点:

  1. 统一错误处理,避免重复回滚调用
  2. 优化查询,减少数据库往返次数
  3. 简化状态查询逻辑
  4. 使用事务上下文确保一致性
  5. 代码结构更清晰,易于维护
回到顶部