从零开始用Golang构建CLI压力测试工具——该从哪里入手?

从零开始用Golang构建CLI压力测试工具——该从哪里入手? 我开始从零构建我的CLI工具,感觉有点迷茫。 我最终决定挑战自己,不使用任何AI生成工具。但这有点困难,因为我真的已经习惯了依赖它们。

到目前为止,我已经决定这个工具将接收一个IP地址作为参数,并向服务器发起大量并发请求(这是我希望能学习并发知识的地方,所以这个想法一直萦绕在我脑海中)。在输出中,它将提供格式美观的信息,包括已处理和失败的请求数量、延迟等。

我想这应该相当简单。我希望在完成HTTP部分后,能增加一些复杂度,例如尝试实现WebSocket的标志参数。

这个想法也让我觉得很有趣,因为我需要处理一些安全问题——如果我理解正确的话,我应该实现一些基本的验证逻辑(比如SSH密钥之类的),以防止这个工具被用于黑帽类型的场景。

如果你有任何资料或建议,请务必告诉我。我的目标是完全不使用AI或任何外部库来完成它,以利于真正的学习。


更多关于从零开始用Golang构建CLI压力测试工具——该从哪里入手?的实战教程也可以访问 https://www.itying.com/category-94-b0.html

2 回复

也许《Go程序设计语言》的这一章会有所帮助:

ch1.pdf (1046.44 KB)

请翻到第17页,查看“1.6. 并发获取URL”。并参考附带的代码:

// Copyright © 2016 Alan A. A. Donovan & Brian W. Kernighan.
// License: https://creativecommons.org/licenses/by-nc-sa/4.0/

// See page 17.
//!+

// Fetchall fetches URLs in parallel and reports their times and sizes.
package main

import (
	"fmt"
	"io"
	"io/ioutil"
	"net/http"
	"os"
	"time"
)

func main() {
	start := time.Now()

更多关于从零开始用Golang构建CLI压力测试工具——该从哪里入手?的实战系列教程也可以访问 https://www.itying.com/category-94-b0.html


从零构建Golang CLI压力测试工具的实现方案

1. 基础项目结构

首先创建基本的CLI框架,处理命令行参数:

// main.go
package main

import (
    "flag"
    "fmt"
    "os"
)

type Config struct {
    URL        string
    Requests   int
    Concurrency int
    Timeout    int
}

func main() {
    config := parseFlags()
    
    fmt.Printf("开始压力测试: %s\n", config.URL)
    fmt.Printf("总请求数: %d, 并发数: %d\n", config.Requests, config.Concurrency)
    
    // 启动压力测试
    runStressTest(config)
}

func parseFlags() *Config {
    config := &Config{}
    
    flag.StringVar(&config.URL, "url", "", "目标URL (必需)")
    flag.IntVar(&config.Requests, "requests", 100, "总请求数")
    flag.IntVar(&config.Concurrency, "concurrency", 10, "并发数")
    flag.IntVar(&config.Timeout, "timeout", 30, "超时时间(秒)")
    
    flag.Parse()
    
    if config.URL == "" {
        fmt.Println("错误: 必须提供 -url 参数")
        os.Exit(1)
    }
    
    return config
}

2. 并发HTTP压力测试核心实现

使用goroutine和channel实现并发控制:

// stress.go
package main

import (
    "fmt"
    "net/http"
    "sync"
    "time"
)

type Result struct {
    Duration time.Duration
    Status   int
    Error    bool
}

type Statistics struct {
    TotalRequests   int
    FailedRequests  int
    SuccessRequests int
    TotalDuration   time.Duration
    MinDuration     time.Duration
    MaxDuration     time.Duration
    mu              sync.Mutex
}

func runStressTest(config *Config) {
    stats := &Statistics{
        MinDuration: time.Hour,
    }
    
    startTime := time.Now()
    
    // 创建工作池
    jobs := make(chan int, config.Requests)
    results := make(chan Result, config.Requests)
    
    // 启动worker
    var wg sync.WaitGroup
    for i := 0; i < config.Concurrency; i++ {
        wg.Add(1)
        go worker(i, config, jobs, results, &wg)
    }
    
    // 分发任务
    for i := 0; i < config.Requests; i++ {
        jobs <- i
    }
    close(jobs)
    
    // 收集结果
    go func() {
        wg.Wait()
        close(results)
    }()
    
    // 处理结果
    for result := range results {
        stats.update(result)
    }
    
    elapsed := time.Since(startTime)
    stats.printReport(elapsed)
}

func worker(id int, config *Config, jobs <-chan int, results chan<- Result, wg *sync.WaitGroup) {
    defer wg.Done()
    
    client := &http.Client{
        Timeout: time.Duration(config.Timeout) * time.Second,
    }
    
    for range jobs {
        start := time.Now()
        
        resp, err := client.Get(config.URL)
        duration := time.Since(start)
        
        result := Result{
            Duration: duration,
            Error:    err != nil,
        }
        
        if err == nil {
            result.Status = resp.StatusCode
            resp.Body.Close()
        }
        
        results <- result
    }
}

func (s *Statistics) update(r Result) {
    s.mu.Lock()
    defer s.mu.Unlock()
    
    s.TotalRequests++
    s.TotalDuration += r.Duration
    
    if r.Error || (r.Status >= 400 && r.Status < 600) {
        s.FailedRequests++
    } else {
        s.SuccessRequests++
    }
    
    if r.Duration < s.MinDuration {
        s.MinDuration = r.Duration
    }
    if r.Duration > s.MaxDuration {
        s.MaxDuration = r.Duration
    }
}

func (s *Statistics) printReport(elapsed time.Duration) {
    avgDuration := time.Duration(0)
    if s.TotalRequests > 0 {
        avgDuration = s.TotalDuration / time.Duration(s.TotalRequests)
    }
    
    fmt.Println("\n=== 压力测试报告 ===")
    fmt.Printf("总耗时: %v\n", elapsed)
    fmt.Printf("总请求数: %d\n", s.TotalRequests)
    fmt.Printf("成功请求: %d (%.1f%%)\n", s.SuccessRequests, 
        float64(s.SuccessRequests)/float64(s.TotalRequests)*100)
    fmt.Printf("失败请求: %d (%.1f%%)\n", s.FailedRequests,
        float64(s.FailedRequests)/float64(s.TotalRequests)*100)
    fmt.Printf("平均延迟: %v\n", avgDuration)
    fmt.Printf("最小延迟: %v\n", s.MinDuration)
    fmt.Printf("最大延迟: %v\n", s.MaxDuration)
    fmt.Printf("请求速率: %.1f req/s\n", 
        float64(s.TotalRequests)/elapsed.Seconds())
}

3. 安全验证实现

添加基本的安全检查防止滥用:

// security.go
package main

import (
    "fmt"
    "net/url"
    "os"
    "strings"
)

func validateTarget(target string) error {
    u, err := url.Parse(target)
    if err != nil {
        return fmt.Errorf("无效的URL: %v", err)
    }
    
    // 禁止本地地址
    if isLocalhost(u.Hostname()) {
        return fmt.Errorf("禁止测试本地地址")
    }
    
    // 检查常见黑名单
    if isBlacklisted(u.Hostname()) {
        return fmt.Errorf("目标地址在黑名单中")
    }
    
    return nil
}

func isLocalhost(hostname string) bool {
    localhosts := []string{
        "localhost",
        "127.0.0.1",
        "::1",
        "0.0.0.0",
    }
    
    for _, lh := range localhosts {
        if hostname == lh || strings.HasPrefix(hostname, lh+".") {
            return true
        }
    }
    return false
}

func isBlacklisted(hostname string) bool {
    // 这里可以添加自定义黑名单逻辑
    // 例如:禁止测试特定域名
    blacklist := []string{
        "example.com",
        "test.com",
    }
    
    for _, bl := range blacklist {
        if strings.Contains(hostname, bl) {
            return true
        }
    }
    return false
}

// 在main函数中添加验证
func main() {
    config := parseFlags()
    
    // 安全验证
    if err := validateTarget(config.URL); err != nil {
        fmt.Printf("安全验证失败: %v\n", err)
        os.Exit(1)
    }
    
    // ... 其余代码
}

4. WebSocket支持扩展

添加WebSocket测试功能:

// websocket.go
package main

import (
    "flag"
    "time"
    "github.com/gorilla/websocket"
)

type WSConfig struct {
    UseWebSocket bool
    Message      string
}

var wsConfig WSConfig

func init() {
    flag.BoolVar(&wsConfig.UseWebSocket, "ws", false, "使用WebSocket协议")
    flag.StringVar(&wsConfig.Message, "ws-message", "test", "WebSocket测试消息")
}

func runWebSocketTest(config *Config) {
    // WebSocket连接实现
    dialer := websocket.Dialer{
        HandshakeTimeout: time.Duration(config.Timeout) * time.Second,
    }
    
    conn, _, err := dialer.Dial(config.URL, nil)
    if err != nil {
        // 处理错误
        return
    }
    defer conn.Close()
    
    // WebSocket消息测试逻辑
    // ...
}

5. 构建和使用

创建go.mod文件:

go mod init stress-tool

构建工具:

go build -o stress-tool

使用示例:

# HTTP测试
./stress-tool -url https://example.com -requests 1000 -concurrency 50

# 带安全验证
./stress-tool -url http://localhost:8080  # 会被拒绝

# WebSocket测试
./stress-tool -url ws://example.com/ws -ws -requests 500

6. 性能优化建议

  • 使用sync.Pool重用对象减少GC压力
  • 实现连接池复用HTTP连接
  • 添加速率限制控制
  • 实现实时进度显示

这个实现完全使用标准库,没有外部依赖,涵盖了并发控制、错误处理、安全验证等核心概念。WebSocket部分需要gorilla/websocket库,但你可以按照类似模式实现自己的WebSocket客户端。

回到顶部