Golang中如何为处理器和Postgres编写测试

我正在尝试为这段代码编写一些测试:

// run docker run --name mypostgres -p 5432:5432 -e POSTGRES_PASSWORD=password -e POSTGRES_USER=postgres -d postgres:17.1

// Command main serves a small users CRUD API: net/http handlers routed by gorilla/mux, each guarded by a bearer-token auth middleware.
package main

import (
	"database/sql"
	"encoding/json"
	"fmt"
	"log"
	"net/http"

	"github.com/gorilla/mux"
	_ "github.com/lib/pq"
)

// Connection settings for the local PostgreSQL instance. The credentials
// match the docker command at the top of this file and are hard-coded, which
// is only suitable for local development.
// NOTE(review): that docker command does not set POSTGRES_DB, so the "mydb"
// database presumably has to be created by hand before Ping succeeds — confirm.
const (
	host     = "localhost"
	port     = 5432
	user     = "postgres"
	password = "password"
	dbname   = "mydb"
)

// db is the shared connection pool every handler queries directly. It is
// assigned by postgres_connection. Because handlers read this global instead
// of receiving a dependency, they can only be tested by swapping it out.
var db *sql.DB

// User is one row of the users table and also the JSON shape used by the API.
// Field order matters: encoding/json emits keys in declaration order, and the
// handlers scan columns in the same id, name, email order.
type User struct {
	ID    int    `json:"id"`
	Name  string `json:"name"`
	Email string `json:"email"`
}

// main connects to PostgreSQL, registers the /users routes and serves HTTP on
// :8080 until the server stops with an error.
func main() {
	postgres_connection()

	router := mux.NewRouter()
	routes(router)

	fmt.Println("Server is listening on port 8080")
	// log.Fatal calls os.Exit, which skips deferred calls, so a deferred
	// db.Close would never run. Close the pool explicitly before exiting.
	err := http.ListenAndServe(":8080", router)
	if cerr := db.Close(); cerr != nil {
		log.Printf("closing database: %v", cerr)
	}
	log.Fatal(err)
}

// routes registers the /users CRUD endpoints on router. Every endpoint is
// wrapped in authMiddleware, so requests must carry a valid Authorization header.
//
// The parameter type is *mux.Router, which is what mux.NewRouter returns.
// mux.NewRouter itself is a function, not a type, and cannot be used here.
func routes(router *mux.Router) {
	router.HandleFunc("/users", authMiddleware(getUsers)).Methods("GET")
	router.HandleFunc("/users", authMiddleware(addUser)).Methods("POST")
	router.HandleFunc("/users/{id}", authMiddleware(updateUser)).Methods("PUT")
	router.HandleFunc("/users/{id}", authMiddleware(deleteUser)).Methods("DELETE")
}

// postgres_connection opens the package-level db pool from the connection
// constants and verifies it with a Ping. Any failure terminates the process
// via log.Fatalf; success prints a confirmation line.
func postgres_connection() {
	dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
		host, port, user, password, dbname)

	var err error
	// sql.Open only validates its arguments; it does not dial the server.
	if db, err = sql.Open("postgres", dsn); err != nil {
		log.Fatalf("Error opening database connection: %v", err)
	}
	// Ping is what actually establishes (and proves) a connection.
	if err = db.Ping(); err != nil {
		log.Fatalf("Error connecting to the database: %v", err)
	}
	fmt.Println("Connected to the PostgreSQL database")
}

// getUsers responds with every row of the users table as a JSON array; an
// empty table yields [] rather than null. Query, scan and iteration errors
// become 500 responses. The query is bound to the request context, so it is
// cancelled if the client goes away.
func getUsers(w http.ResponseWriter, r *http.Request) {
	rows, err := db.QueryContext(r.Context(), "SELECT id, name, email FROM users")
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	defer rows.Close()

	// A non-nil empty slice encodes as [] instead of null.
	users := []User{}
	for rows.Next() {
		var u User
		if err := rows.Scan(&u.ID, &u.Name, &u.Email); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		users = append(users, u)
	}
	// rows.Next returns false both at the end of the result set and on an
	// error; only rows.Err distinguishes them.
	if err := rows.Err(); err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(users); err != nil {
		// The status line is already committed; all we can do is log.
		log.Printf("encoding users response: %v", err)
	}
}

// addUser decodes a User from the JSON request body and inserts its name and
// email into the users table (any id in the body is ignored). It responds 400
// on a malformed body, 500 on a database error, and 201 with a confirmation
// message on success.
func addUser(w http.ResponseWriter, r *http.Request) {
	var payload User
	if decodeErr := json.NewDecoder(r.Body).Decode(&payload); decodeErr != nil {
		http.Error(w, decodeErr.Error(), http.StatusBadRequest)
		return
	}

	const insertSQL = "INSERT INTO users (name, email) VALUES ($1, $2)"
	if _, execErr := db.Exec(insertSQL, payload.Name, payload.Email); execErr != nil {
		http.Error(w, execErr.Error(), http.StatusInternalServerError)
		return
	}

	w.WriteHeader(http.StatusCreated)
	fmt.Fprint(w, "User added successfully")
}

// updateUser overwrites the name and email of the user whose id comes from
// the {id} route variable, using the JSON User in the request body. It
// responds 400 on a malformed body, 500 on a database error, 404 when no row
// has that id, and a confirmation message otherwise.
func updateUser(w http.ResponseWriter, r *http.Request) {
	id := mux.Vars(r)["id"]

	var user User
	if err := json.NewDecoder(r.Body).Decode(&user); err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	res, err := db.ExecContext(r.Context(),
		"UPDATE users SET name=$1, email=$2 WHERE id=$3", user.Name, user.Email, id)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	// An UPDATE whose WHERE matches nothing is not a SQL error; use the
	// affected-row count so a missing user is not reported as updated.
	affected, err := res.RowsAffected()
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	if affected == 0 {
		http.Error(w, "User not found", http.StatusNotFound)
		return
	}

	fmt.Fprintf(w, "User updated successfully")
}

// deleteUser removes the user whose id comes from the {id} route variable.
// It responds 500 on a database error, 404 when no row has that id, and a
// confirmation message otherwise.
func deleteUser(w http.ResponseWriter, r *http.Request) {
	id := mux.Vars(r)["id"]

	res, err := db.ExecContext(r.Context(), "DELETE FROM users WHERE id=$1", id)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	// Deleting a non-existent id is not a SQL error; detect it through the
	// affected-row count instead of reporting success.
	affected, err := res.RowsAffected()
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	if affected == 0 {
		http.Error(w, "User not found", http.StatusNotFound)
		return
	}

	fmt.Fprintf(w, "User deleted successfully")
}

// authMiddleware wraps next so that it only runs when the request's
// Authorization header is exactly "Bearer secret". Any other value, including
// a missing header, is answered with 401 Unauthorized and next is not called.
func authMiddleware(next http.HandlerFunc) http.HandlerFunc {
	const wantAuth = "Bearer secret"
	return func(w http.ResponseWriter, r *http.Request) {
		if r.Header.Get("Authorization") == wantAuth {
			next.ServeHTTP(w, r)
			return
		}
		http.Error(w, "Unauthorized", http.StatusUnauthorized)
	}
}

当涉及到处理程序(handlers)和PostgreSQL时,编写测试对我来说有点困难。我本想模拟数据库,但问题是我试图以错误的方式覆盖PostgreSQL数据库,所以当我运行测试时,会遇到一些“内存指针问题”。

以下是我尝试测试第一个处理程序的方法:

package main

import (
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/gorilla/mux"
	"github.com/pashagolub/pgxmock"
)

// TestGetUsers tries to drive the package-level getUsers handler against a
// pgxmock connection. The connection expects the users SELECT and returns two
// rows, and the test compares the exact response body.
//
// NOTE(review): this does not compile as written. pgxmock.NewConn returns a
// mock of pgx's own connection API, while db is declared as *sql.DB, so
// `db = mockDB` is a type mismatch. Mocking a *sql.DB needs a
// database/sql-level mock (e.g. go-sqlmock), or the handler has to depend on
// an interface instead of the global db.
func TestGetUsers(t *testing.T) {
	mockDB, err := pgxmock.NewConn()
	if err != nil {
		t.Fatalf("Error initializing mock database: %v", err)
	}
	// NOTE(review): for a pgx-style connection mock, Close may require a
	// context argument — confirm against the pgxmock version in use.
	defer mockDB.Close()

	// NOTE(review): pgxmock presumably matches this string as a regular
	// expression by default; it contains no metacharacters, so that is harmless here.
	mockDB.ExpectQuery("SELECT id, name, email FROM users").
		WillReturnRows(
			mockDB.NewRows([]string{"id", "name", "email"}).
				AddRow(1, "Laky", "laky@example.com").
				AddRow(2, "Luma", "luma@example.com"),
		)

	// NOTE(review): type mismatch (mock connection vs *sql.DB); see the
	// function comment. This is the line the test cannot get past.
	db = mockDB

	req := httptest.NewRequest("GET", "/users", nil)
	w := httptest.NewRecorder()

	// The handler is registered without authMiddleware, so no Authorization
	// header is needed here.
	router := mux.NewRouter()
	router.HandleFunc("/users", getUsers).Methods("GET")
	router.ServeHTTP(w, req)

	if w.Code != http.StatusOK {
		t.Errorf("Expected status OK, got %v", w.Code)
	}

	// NOTE(review): getUsers writes the body with json.Encoder.Encode, which
	// appends a trailing newline, so this exact comparison fails even when the
	// JSON matches. Compare against expected+"\n" or decode the body first.
	expected := `[{"id":1,"name":"Laky","email":"laky@example.com"},{"id":2,"name":"Luma","email":"luma@example.com"}]`
	if w.Body.String() != expected {
		t.Errorf("Expected body %v, got %v", expected, w.Body.String())
	}
}

谢谢。


更多关于Golang中如何为处理器和Postgres编写测试的实战教程也可以访问 https://www.itying.com/category-94-b0.html

1 回复

更多关于Golang中如何为处理器和Postgres编写测试的实战系列教程也可以访问 https://www.itying.com/category-94-b0.html


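针对你的代码,这里提供两种测试方案:使用接口解耦的单元测试,和使用 testcontainers 的集成测试。

先说明你原来的测试为什么会失败:
- 类型不匹配:pgxmock 模拟的是 pgx 驱动自己的连接接口,而全局变量 `db` 的类型是 database/sql 的 `*sql.DB`,所以 `db = mockDB` 无法通过编译。
- 换行符:`json.NewEncoder(w).Encode` 会在输出末尾追加换行符,直接与不带换行的字符串做精确比较也会失败。

如果想继续直接模拟 `*sql.DB`,应改用 go-sqlmock 这类基于 database/sql 的模拟库;更推荐的做法是像下面这样通过接口解耦。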

方案一:使用接口和依赖注入的单元测试

首先重构代码,引入接口解耦:

// user_repository.go
package main

import (
	"database/sql"
)

// UserRepository abstracts persistence of User records, so handlers can be
// tested against an in-memory implementation instead of a live database.
type UserRepository interface {
	// GetAll returns every stored user.
	GetAll() ([]User, error)
	// Create stores a new user.
	Create(user User) error
	// Update overwrites the name and email of the user with the given id.
	Update(id int, user User) error
	// Delete removes the user with the given id.
	Delete(id int) error
}

// PostgresUserRepository implements UserRepository on top of a database/sql
// connection pool, using PostgreSQL-style ($1, $2, ...) placeholders.
type PostgresUserRepository struct {
	db *sql.DB
}

// NewPostgresUserRepository returns a repository that runs its queries on db.
// The caller keeps ownership of db and is responsible for closing it.
func NewPostgresUserRepository(db *sql.DB) *PostgresUserRepository {
	repo := new(PostgresUserRepository)
	repo.db = db
	return repo
}

// GetAll returns every row of the users table, or a nil slice when the table
// is empty. It returns an error if the query, any row scan, or the row
// iteration itself fails.
func (r *PostgresUserRepository) GetAll() ([]User, error) {
	rows, err := r.db.Query("SELECT id, name, email FROM users")
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var users []User
	for rows.Next() {
		var user User
		if err := rows.Scan(&user.ID, &user.Name, &user.Email); err != nil {
			return nil, err
		}
		users = append(users, user)
	}
	// rows.Next returns false both at the end of the result set and on a
	// driver or network error; only rows.Err tells the two apart.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return users, nil
}

// Create inserts user's name and email into the users table. user.ID is not
// sent, so the database assigns the id.
func (r *PostgresUserRepository) Create(user User) error {
	const insertSQL = "INSERT INTO users (name, email) VALUES ($1, $2)"
	if _, err := r.db.Exec(insertSQL, user.Name, user.Email); err != nil {
		return err
	}
	return nil
}

// Update sets the name and email of the row with the given id. Matching zero
// rows is not reported as an error.
func (r *PostgresUserRepository) Update(id int, user User) error {
	const updateSQL = "UPDATE users SET name=$1, email=$2 WHERE id=$3"
	if _, err := r.db.Exec(updateSQL, user.Name, user.Email, id); err != nil {
		return err
	}
	return nil
}

// Delete removes the row with the given id. Deleting a non-existent id is not
// reported as an error.
func (r *PostgresUserRepository) Delete(id int) error {
	const deleteSQL = "DELETE FROM users WHERE id=$1"
	if _, err := r.db.Exec(deleteSQL, id); err != nil {
		return err
	}
	return nil
}

更新处理器使用依赖注入:

// handlers.go
package main

import (
	"encoding/json"
	"net/http"
	"strconv"

	"github.com/gorilla/mux"
)

// Handlers groups the HTTP handlers of the users API. All persistence goes
// through userRepo, so tests can substitute an in-memory implementation.
type Handlers struct {
	userRepo UserRepository
}

// NewHandlers returns a Handlers that loads and stores users via userRepo.
func NewHandlers(userRepo UserRepository) *Handlers {
	h := new(Handlers)
	h.userRepo = userRepo
	return h
}

// GetUsers responds with every user from the repository as a JSON array.
// An empty repository yields [] rather than null, so clients can always treat
// the body as an array. Repository errors become 500 responses.
func (h *Handlers) GetUsers(w http.ResponseWriter, r *http.Request) {
	users, err := h.userRepo.GetAll()
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	if users == nil {
		// encoding/json renders a nil slice as null.
		users = []User{}
	}

	w.Header().Set("Content-Type", "application/json")
	// Once the body starts streaming the status is committed, so an encode
	// failure (typically a disconnected client) cannot be reported; ignore it.
	_ = json.NewEncoder(w).Encode(users)
}

// AddUser decodes a User from the JSON request body and stores it via the
// repository. It responds 400 on a malformed body, 500 on a repository error,
// and 201 with a JSON {"message": ...} object on success.
func (h *Handlers) AddUser(w http.ResponseWriter, r *http.Request) {
	var payload User
	if decodeErr := json.NewDecoder(r.Body).Decode(&payload); decodeErr != nil {
		http.Error(w, decodeErr.Error(), http.StatusBadRequest)
		return
	}

	if createErr := h.userRepo.Create(payload); createErr != nil {
		http.Error(w, createErr.Error(), http.StatusInternalServerError)
		return
	}

	w.WriteHeader(http.StatusCreated)
	reply := map[string]string{"message": "User added successfully"}
	json.NewEncoder(w).Encode(reply)
}

// UpdateUser replaces the user identified by the numeric {id} route variable
// with the JSON User in the request body. It responds 400 for a non-numeric id
// or a malformed body, 500 on a repository error, and a JSON
// {"message": ...} object on success.
func (h *Handlers) UpdateUser(w http.ResponseWriter, r *http.Request) {
	id, convErr := strconv.Atoi(mux.Vars(r)["id"])
	if convErr != nil {
		http.Error(w, "Invalid user ID", http.StatusBadRequest)
		return
	}

	var payload User
	if decodeErr := json.NewDecoder(r.Body).Decode(&payload); decodeErr != nil {
		http.Error(w, decodeErr.Error(), http.StatusBadRequest)
		return
	}

	if updateErr := h.userRepo.Update(id, payload); updateErr != nil {
		http.Error(w, updateErr.Error(), http.StatusInternalServerError)
		return
	}

	reply := map[string]string{"message": "User updated successfully"}
	json.NewEncoder(w).Encode(reply)
}

// DeleteUser removes the user identified by the numeric {id} route variable.
// It responds 400 for a non-numeric id, 500 on a repository error, and a JSON
// {"message": ...} object on success.
func (h *Handlers) DeleteUser(w http.ResponseWriter, r *http.Request) {
	id, convErr := strconv.Atoi(mux.Vars(r)["id"])
	if convErr != nil {
		http.Error(w, "Invalid user ID", http.StatusBadRequest)
		return
	}

	if deleteErr := h.userRepo.Delete(id); deleteErr != nil {
		http.Error(w, deleteErr.Error(), http.StatusInternalServerError)
		return
	}

	reply := map[string]string{"message": "User deleted successfully"}
	json.NewEncoder(w).Encode(reply)
}

创建模拟存储库(建议放在以 _test.go 结尾的文件中,例如 mock_repository_test.go,这样它只会在测试时编译,不会进入正式程序):

// mock_repository.go
package main

// MockUserRepository is an in-memory UserRepository for handler tests. It
// ignores its arguments. GetAll returns the canned users, and every method
// returns err, which lets a test force the error path.
type MockUserRepository struct {
	users []User // returned by GetAll
	err   error  // returned by every method
}

// GetAll returns the canned users together with the configured error.
func (repo *MockUserRepository) GetAll() ([]User, error) { return repo.users, repo.err }

// Create discards user and returns the configured error.
func (repo *MockUserRepository) Create(user User) error { return repo.err }

// Update discards its arguments and returns the configured error.
func (repo *MockUserRepository) Update(id int, user User) error { return repo.err }

// Delete discards id and returns the configured error.
func (repo *MockUserRepository) Delete(id int) error { return repo.err }

编写单元测试:

// handlers_test.go
package main

import (
	"bytes"
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/gorilla/mux"
)

func TestGetUsersHandler(t *testing.T) {
	mockRepo := &MockUserRepository{
		users: []User{
			{ID: 1, Name: "Laky", Email: "laky@example.com"},
			{ID: 2, Name: "Luma", Email: "luma@example.com"},
		},
	}

	handlers := NewHandlers(mockRepo)

	req := httptest.NewRequest("GET", "/users", nil)
	w := httptest.NewRecorder()

	handlers.GetUsers(w, req)

	if w.Code != http.StatusOK {
		t.Errorf("Expected status OK, got %v", w.Code)
	}

	var users []User
	err := json.NewDecoder(w.Body).Decode(&users)
	if err != nil {
		t.Fatalf("Failed to decode response: %v", err)
	}

	if len(users) != 2 {
		t.Errorf("Expected 2 users, got %d", len(users))
	}
}

func TestAddUserHandler(t *testing.T) {
	mockRepo := &MockUserRepository{}
	handlers := NewHandlers(mockRepo)

	user := User{Name: "Test User", Email: "test@example.com"}
	body, _ := json.Marshal(user)

	req := httptest.NewRequest("POST", "/users", bytes.NewBuffer(body))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()

	handlers.AddUser(w, req)

	if w.Code != http.StatusCreated {
		t.Errorf("Expected status Created, got %v", w.Code)
	}
}

func TestUpdateUserHandler(t *testing.T) {
	mockRepo := &MockUserRepository{}
	handlers := NewHandlers(mockRepo)

	user := User{Name: "Updated User", Email: "updated@example.com"}
	body, _ := json.Marshal(user)

	req := httptest.NewRequest("PUT", "/users/1", bytes.NewBuffer(body))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()

	router := mux.NewRouter()
	router.HandleFunc("/users/{id}", handlers.UpdateUser).Methods("PUT")
	router.ServeHTTP(w, req)

	if w.Code != http.StatusOK {
		t.Errorf("Expected status OK, got %v", w.Code)
	}
}

func TestDeleteUserHandler(t *testing.T) {
	mockRepo := &MockUserRepository{}
	handlers := NewHandlers(mockRepo)

	req := httptest.NewRequest("DELETE", "/users/1", nil)
	w := httptest.NewRecorder()

	router := mux.NewRouter()
	router.HandleFunc("/users/{id}", handlers.DeleteUser).Methods("DELETE")
	router.ServeHTTP(w, req)

	if w.Code != http.StatusOK {
		t.Errorf("Expected status OK, got %v", w.Code)
	}
}

// TestAuthMiddleware wraps GetUsers in authMiddleware and checks that only the
// exact "Bearer secret" header gets through. A wrong or missing header must
// yield 401.
func TestAuthMiddleware(t *testing.T) {
	repo := &MockUserRepository{
		users: []User{{ID: 1, Name: "Test", Email: "test@example.com"}},
	}
	protected := authMiddleware(NewHandlers(repo).GetUsers)

	cases := []struct {
		name   string
		header string
		want   int
	}{
		{name: "Valid token", header: "Bearer secret", want: http.StatusOK},
		{name: "Invalid token", header: "Bearer wrong", want: http.StatusUnauthorized},
		{name: "No token", header: "", want: http.StatusUnauthorized},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			req := httptest.NewRequest("GET", "/users", nil)
			if tc.header != "" {
				req.Header.Set("Authorization", tc.header)
			}
			rec := httptest.NewRecorder()

			protected.ServeHTTP(rec, req)

			if rec.Code != tc.want {
				t.Errorf("Expected status %v, got %v", tc.want, rec.Code)
			}
		})
	}
}

方案二:使用testcontainers的集成测试

// integration_test.go
package main

import (
	"bytes"
	"context"
	"database/sql"
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"github.com/gorilla/mux"
	"github.com/testcontainers/testcontainers-go"
	"github.com/testcontainers/testcontainers-go/modules/postgres"
	"github.com/testcontainers/testcontainers-go/wait"
)

// setupTestDatabase starts a throwaway postgres:17.1 container, opens a
// database/sql connection to it, and creates the users table. Any failure
// aborts the test via t.Fatalf.
//
// The container and connection are registered with t.Cleanup as soon as they
// exist, so they are released even when a later setup step fails. The returned
// cleanup func releases them early (e.g. via defer); calling it again, or
// letting t.Cleanup run it afterwards, is a no-op.
func setupTestDatabase(t *testing.T) (*sql.DB, func()) {
	t.Helper()
	ctx := context.Background()

	pgContainer, err := postgres.Run(ctx,
		"postgres:17.1",
		postgres.WithDatabase("testdb"),
		postgres.WithUsername("postgres"),
		postgres.WithPassword("password"),
		testcontainers.WithWaitStrategy(
			wait.ForLog("database system is ready to accept connections").
				WithOccurrence(2).
				WithStartupTimeout(5*time.Second)),
	)

	var (
		testDB *sql.DB
		done   bool
	)
	cleanup := func() {
		if done {
			return
		}
		done = true
		if testDB != nil {
			if cerr := testDB.Close(); cerr != nil {
				t.Logf("Failed to close database: %v", cerr)
			}
		}
		// Run can hand back a created-but-unhealthy container together with an
		// error, so terminate whenever we got one.
		if pgContainer != nil {
			if terr := pgContainer.Terminate(ctx); terr != nil {
				t.Logf("Failed to terminate container: %v", terr)
			}
		}
	}
	t.Cleanup(cleanup)

	if err != nil {
		t.Fatalf("Failed to start container: %v", err)
	}

	// The container does not enable TLS, and lib/pq defaults to
	// sslmode=require, so SSL must be disabled explicitly.
	connStr, err := pgContainer.ConnectionString(ctx, "sslmode=disable")
	if err != nil {
		t.Fatalf("Failed to get connection string: %v", err)
	}

	testDB, err = sql.Open("postgres", connStr)
	if err != nil {
		t.Fatalf("Failed to connect to database: %v", err)
	}
	// sql.Open does not dial; Ping surfaces connection problems here rather
	// than in the first test query.
	if err := testDB.PingContext(ctx); err != nil {
		t.Fatalf("Failed to ping database: %v", err)
	}

	if _, err := testDB.Exec(`
		CREATE TABLE IF NOT EXISTS users (
			id SERIAL PRIMARY KEY,
			name VARCHAR(100) NOT NULL,
			email VARCHAR(100) NOT NULL UNIQUE
		)
	`); err != nil {
		t.Fatalf("Failed to create table: %v", err)
	}

	return testDB, cleanup
}

// TestIntegrationGetUsers seeds two rows directly in a real PostgreSQL
// container. It then checks that the authenticated GetUsers handler returns
// both of them through the Postgres-backed repository.
func TestIntegrationGetUsers(t *testing.T) {
	testDB, cleanup := setupTestDatabase(t)
	defer cleanup()

	// Seed directly with SQL so only the read path is under test.
	const seedSQL = "INSERT INTO users (name, email) VALUES ($1, $2), ($3, $4)"
	if _, err := testDB.Exec(seedSQL, "Laky", "laky@example.com", "Luma", "luma@example.com"); err != nil {
		t.Fatalf("Failed to insert test data: %v", err)
	}

	protected := authMiddleware(NewHandlers(NewPostgresUserRepository(testDB)).GetUsers)

	req := httptest.NewRequest("GET", "/users", nil)
	req.Header.Set("Authorization", "Bearer secret")
	rec := httptest.NewRecorder()
	protected.ServeHTTP(rec, req)

	if rec.Code != http.StatusOK {
		t.Errorf("Expected status OK, got %v", rec.Code)
	}

	var users []User
	if err := json.NewDecoder(rec.Body).Decode(&users); err != nil {
		t.Fatalf("Failed to decode response: %v", err)
	}

	if got := len(users); got != 2 {
		t.Errorf("Expected 2 users, got %d", got)
	}
}

// TestIntegrationCRUDOperations drives the full authenticated router against a
// real PostgreSQL container. It creates a user, reads it back, updates it, and
// deletes it. The subtests share one database and must run in order.
func TestIntegrationCRUDOperations(t *testing.T) {
	testDB, cleanup := setupTestDatabase(t)
	defer cleanup()

	repo := NewPostgresUserRepository(testDB)
	handlers := NewHandlers(repo)
	router := mux.NewRouter()

	router.HandleFunc("/users", authMiddleware(handlers.GetUsers)).Methods("GET")
	router.HandleFunc("/users", authMiddleware(handlers.AddUser)).Methods("POST")
	router.HandleFunc("/users/{id}", authMiddleware(handlers.UpdateUser)).Methods("PUT")
	router.HandleFunc("/users/{id}", authMiddleware(handlers.DeleteUser)).Methods("DELETE")

	// send issues an authenticated request through the router. A nil payload
	// sends an empty body; anything else is JSON-encoded.
	send := func(t *testing.T, method, path string, payload any) *httptest.ResponseRecorder {
		t.Helper()
		var body []byte
		if payload != nil {
			var err error
			if body, err = json.Marshal(payload); err != nil {
				t.Fatalf("Failed to marshal request body: %v", err)
			}
		}
		req := httptest.NewRequest(method, path, bytes.NewReader(body))
		req.Header.Set("Authorization", "Bearer secret")
		req.Header.Set("Content-Type", "application/json")
		w := httptest.NewRecorder()
		router.ServeHTTP(w, req)
		return w
	}

	// listUsers fetches GET /users and decodes the JSON array.
	listUsers := func(t *testing.T) []User {
		t.Helper()
		w := send(t, "GET", "/users", nil)
		if w.Code != http.StatusOK {
			t.Fatalf("Expected status OK, got %v", w.Code)
		}
		var users []User
		if err := json.NewDecoder(w.Body).Decode(&users); err != nil {
			t.Fatalf("Failed to decode response: %v", err)
		}
		return users
	}

	t.Run("Create and retrieve user", func(t *testing.T) {
		w := send(t, "POST", "/users", User{Name: "Test User", Email: "test@example.com"})
		if w.Code != http.StatusCreated {
			t.Fatalf("Expected status Created, got %v", w.Code)
		}

		users := listUsers(t)
		if len(users) != 1 {
			t.Fatalf("Expected 1 user, got %d", len(users))
		}
		if users[0].Name != "Test User" || users[0].Email != "test@example.com" {
			t.Errorf("Unexpected user after create: %+v", users[0])
		}
	})

	t.Run("Update user", func(t *testing.T) {
		users := listUsers(t)
		if len(users) != 1 {
			t.Fatalf("Expected 1 user before update, got %d", len(users))
		}

		path := fmt.Sprintf("/users/%d", users[0].ID)
		w := send(t, "PUT", path, User{Name: "Updated User", Email: "updated@example.com"})
		if w.Code != http.StatusOK {
			t.Fatalf("Expected status OK, got %v", w.Code)
		}

		users = listUsers(t)
		if len(users) != 1 || users[0].Name != "Updated User" || users[0].Email != "updated@example.com" {
			t.Errorf("Unexpected users after update: %+v", users)
		}
	})

	t.Run("Delete user", func(t *testing.T) {
		users := listUsers(t)
		if len(users) != 1 {
			t.Fatalf("Expected 1 user before delete, got %d", len(users))
		}

		w := send(t, "DELETE", fmt.Sprintf("/users/%d", users[0].ID), nil)
		if w.Code != http.StatusOK {
			t.Fatalf("Expected status OK, got %v", w.Code)
		}

		if remaining := listUsers(t); len(remaining) != 0 {
			t.Errorf("Expected 0 users after delete, got %d", len(remaining))
		}
	})
}

测试运行命令

# 运行单元测试
go test -v -run TestGetUsersHandler
go test -v -run TestAuthMiddleware

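# 运行集成测试(需要Docker)
# 注意:只有在 integration_test.go 文件顶部加上 `//go:build integration` 时,-tags=integration 才会生效;
# 否则集成测试没有构建标签,也会被下面的 `go test ./...` 等命令一并执行
go test -v -tags=integration -run TestIntegration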

# 运行所有测试
go test -v ./...

# 生成测试覆盖率报告
go test -coverprofile=coverage.out
go tool cover -html=coverage.out

测试依赖

// go.mod 需要添加的依赖
module your-module-name

go 1.21

require (
    github.com/gorilla/mux v1.8.1
    github.com/lib/pq v1.10.9
    // postgres.Run(ctx, image, ...) 需要 testcontainers-go v0.32.0 或更高版本(更早的版本使用 postgres.RunContainer)
    github.com/testcontainers/testcontainers-go v0.34.0
    github.com/testcontainers/testcontainers-go/modules/postgres v0.34.0
)

第一种方案通过接口解耦实现了纯单元测试,不依赖外部服务。第二种方案使用testcontainers创建真实的PostgreSQL容器进行集成测试,更接近生产环境。建议结合使用两种方案,单元测试用于快速验证逻辑,集成测试用于验证数据库交互。

回到顶部