Golang 限制任务并发执行的 workerpool 使用指南
在 Go 语言中,我们经常需要限制并发执行的任务数量,而不是简单地限制 goroutine 的数量。workerpool 是一种常用的并发模式,它可以有效地控制并发任务的数量,同时复用 goroutine 资源。
常用的 workerpool 实现
以下是几种常用的 workerpool 实现方式:
1. 使用 channel 实现基础 workerpool
package main
import (
"fmt"
"sync"
"time"
)
// worker consumes jobs until the jobs channel is closed. For each job it logs
// the start and finish, sleeps one second as a stand-in for real work, and
// sends twice the job value on results.
func worker(id int, jobs <-chan int, results chan<- int) {
	for {
		job, ok := <-jobs
		if !ok {
			return // jobs was closed and fully drained
		}
		fmt.Printf("worker %d started job %d\n", id, job)
		time.Sleep(time.Second) // simulated slow task
		fmt.Printf("worker %d finished job %d\n", id, job)
		results <- 2 * job
	}
}
// main fans numJobs jobs out to numWorkers worker goroutines, so at most
// numWorkers jobs are processed at the same time, and prints every result.
func main() {
	const numJobs = 10
	const numWorkers = 3

	jobs := make(chan int, numJobs)
	results := make(chan int, numJobs)

	// Start the workers and track them so we know when all of them have exited.
	var wg sync.WaitGroup
	for w := 1; w <= numWorkers; w++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			worker(id, jobs, results)
		}(w)
	}

	// Enqueue every job, then close jobs so each worker's range loop ends.
	for j := 1; j <= numJobs; j++ {
		jobs <- j
	}
	close(jobs)

	// Close results only after every worker has returned; the range below then
	// terminates on its own and no worker goroutine is left running.
	go func() {
		wg.Wait()
		close(results)
	}()

	for r := range results {
		fmt.Println("result:", r)
	}
}
2. 使用 errgroup 实现带错误处理的 workerpool
package main
import (
"context"
"fmt"
"golang.org/x/sync/errgroup"
"time"
)
// main runs numTasks simulated tasks with at most maxConcurrency in flight.
// The errgroup reports the first error any task returns, and through ctx it
// tells the remaining tasks to stop early.
func main() {
	const numTasks = 10
	const maxConcurrency = 3

	// ctx is cancelled the first time a g.Go callback returns a non-nil error
	// (or when Wait returns).
	g, ctx := errgroup.WithContext(context.Background())
	// At most maxConcurrency callbacks run at once; g.Go blocks until a slot frees.
	g.SetLimit(maxConcurrency)

	for i := 0; i < numTasks; i++ {
		taskID := i // per-iteration copy for the closure (required before Go 1.22)
		g.Go(func() error {
			// Do not start work if the group has already been cancelled.
			if err := ctx.Err(); err != nil {
				return err
			}
			fmt.Printf("Starting task %d\n", taskID)
			// Simulated work that, unlike time.Sleep, stops as soon as
			// another task's failure cancels ctx.
			select {
			case <-time.After(time.Second):
			case <-ctx.Done():
				return ctx.Err()
			}
			fmt.Printf("Finished task %d\n", taskID)
			return nil
		})
	}

	// Wait blocks until every task has returned and yields the first error.
	if err := g.Wait(); err != nil {
		fmt.Printf("Error occurred: %v\n", err)
	} else {
		fmt.Println("All tasks completed successfully")
	}
}
3. 使用第三方库 workerpool
一个流行的第三方 workerpool 实现是 github.com/gammazero/workerpool:
package main
import (
"fmt"
"time"
"github.com/gammazero/workerpool"
)
// main submits ten simulated one-second tasks to a gammazero workerpool that
// runs at most three of them concurrently, then waits for all of them.
func main() {
	pool := workerpool.New(3) // maximum number of concurrent workers

	for n := 0; n < 10; n++ {
		id := n // per-iteration copy captured by the closure
		pool.Submit(func() {
			fmt.Printf("Starting task %d\n", id)
			time.Sleep(time.Second) // stand-in for real work
			fmt.Printf("Finished task %d\n", id)
		})
	}

	// StopWait stops the pool after every queued task has run.
	pool.StopWait()
	fmt.Println("All tasks completed")
}
高级 workerpool 功能
1. 带优先级的 workerpool
package main
import (
"container/heap"
"fmt"
"sync"
"time"
)
// Task is one unit of work queued in a PriorityWorkerPool.
type Task struct {
	priority int    // larger value = runs earlier
	fn       func() // the work itself
}

// PriorityQueue is a max-heap of tasks ordered by priority. It implements
// heap.Interface and must be manipulated through container/heap.
type PriorityQueue []*Task

func (pq PriorityQueue) Len() int { return len(pq) }

func (pq PriorityQueue) Less(i, j int) bool {
	return pq[i].priority > pq[j].priority // larger number = higher priority
}

func (pq PriorityQueue) Swap(i, j int) {
	pq[i], pq[j] = pq[j], pq[i]
}

// Push appends x, which must be a *Task. Called by heap.Push.
func (pq *PriorityQueue) Push(x interface{}) {
	*pq = append(*pq, x.(*Task))
}

// Pop removes and returns the last element. Called by heap.Pop.
func (pq *PriorityQueue) Pop() interface{} {
	old := *pq
	n := len(old)
	item := old[n-1]
	old[n-1] = nil // drop the reference so the finished task can be collected
	*pq = old[:n-1]
	return item
}

// PriorityWorkerPool runs submitted tasks on a fixed number of worker
// goroutines. Pending tasks wait in a max-heap, so whenever a worker becomes
// free it takes the highest-priority task still queued. Tasks with equal
// priority run in no guaranteed order.
type PriorityWorkerPool struct {
	mu     sync.Mutex
	cond   *sync.Cond    // signalled when a task is queued or the pool stops
	queue  PriorityQueue // pending tasks, guarded by mu
	closed bool          // set by Stop, guarded by mu
	wg     sync.WaitGroup
}

// NewPriorityWorkerPool starts `workers` goroutines that execute submitted tasks.
func NewPriorityWorkerPool(workers int) *PriorityWorkerPool {
	p := &PriorityWorkerPool{}
	p.cond = sync.NewCond(&p.mu)
	for i := 0; i < workers; i++ {
		p.wg.Add(1)
		go p.worker()
	}
	return p
}

// worker repeatedly pops the highest-priority task and runs it outside the
// lock. It exits once the pool is stopped and the queue has been drained.
func (p *PriorityWorkerPool) worker() {
	defer p.wg.Done()
	for {
		p.mu.Lock()
		for p.queue.Len() == 0 && !p.closed {
			p.cond.Wait()
		}
		if p.queue.Len() == 0 { // stopped and nothing left to run
			p.mu.Unlock()
			return
		}
		task := heap.Pop(&p.queue).(*Task)
		p.mu.Unlock()
		task.fn()
	}
}

// Submit queues fn with the given priority (larger runs earlier). It never
// blocks. Calling Submit after Stop panics, just as sending on a closed
// channel would.
func (p *PriorityWorkerPool) Submit(priority int, fn func()) {
	p.mu.Lock()
	defer p.mu.Unlock()
	if p.closed {
		panic("PriorityWorkerPool: Submit called after Stop")
	}
	heap.Push(&p.queue, &Task{priority: priority, fn: fn})
	p.cond.Signal()
}

// Stop prevents further submissions, lets the workers finish every queued
// task, and blocks until all workers have exited.
func (p *PriorityWorkerPool) Stop() {
	p.mu.Lock()
	p.closed = true
	p.mu.Unlock()
	p.cond.Broadcast() // wake idle workers so they can observe closed
	p.wg.Wait()
}
// main submits ten one-second tasks, with priorities cycling 0, 1, 2, to a
// three-worker PriorityWorkerPool, then calls Stop, which waits for the
// workers to exit.
func main() {
	pool := NewPriorityWorkerPool(3)

	for n := 0; n < 10; n++ {
		id, prio := n, n%3 // per-iteration copies captured by the closure
		pool.Submit(prio, func() {
			fmt.Printf("Starting task %d with priority %d\n", id, prio)
			time.Sleep(time.Second)
			fmt.Printf("Finished task %d\n", id)
		})
	}

	pool.Stop()
}
2. 带超时控制的 workerpool
package main
import (
"context"
"fmt"
"sync"
"time"
)
// defaultTaskTimeout is the per-task deadline used by NewTimeoutWorkerPool.
const defaultTaskTimeout = 2 * time.Second

// TimeoutWorkerPool runs submitted tasks on a fixed number of workers and
// gives each task a context whose deadline is `timeout` after a worker picks
// it up. Tasks are expected to watch ctx.Done() and return when it fires.
type TimeoutWorkerPool struct {
	wg        sync.WaitGroup
	taskQueue chan func(ctx context.Context)
	timeout   time.Duration
}

// NewTimeoutWorkerPool starts `workers` goroutines with a 2-second per-task deadline.
func NewTimeoutWorkerPool(workers int) *TimeoutWorkerPool {
	return NewTimeoutWorkerPoolWithTimeout(workers, defaultTaskTimeout)
}

// NewTimeoutWorkerPoolWithTimeout starts `workers` goroutines whose tasks each
// receive a context that expires `timeout` after the task starts running.
func NewTimeoutWorkerPoolWithTimeout(workers int, timeout time.Duration) *TimeoutWorkerPool {
	p := &TimeoutWorkerPool{
		taskQueue: make(chan func(ctx context.Context), 100),
		timeout:   timeout,
	}
	for i := 0; i < workers; i++ {
		p.wg.Add(1)
		go p.worker()
	}
	return p
}

// worker executes queued tasks until the queue is closed by Stop.
func (p *TimeoutWorkerPool) worker() {
	defer p.wg.Done()
	for task := range p.taskQueue {
		p.run(task)
	}
}

// run executes a single task under its own deadline. The deferred cancel
// releases the context's timer as soon as the task returns.
func (p *TimeoutWorkerPool) run(task func(ctx context.Context)) {
	ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
	defer cancel()
	task(ctx)
}

// Submit queues fn; it blocks while the 100-slot queue is full.
func (p *TimeoutWorkerPool) Submit(fn func(ctx context.Context)) {
	p.taskQueue <- fn
}

// Stop closes the queue and waits until every queued task has finished.
func (p *TimeoutWorkerPool) Stop() {
	close(p.taskQueue)
	p.wg.Wait()
}
// main submits five tasks of increasing length to a three-worker pool whose
// tasks each get a 2-second deadline, and reports which ones finish in time.
func main() {
	pool := NewTimeoutWorkerPool(3)
	for i := 0; i < 5; i++ {
		taskID := i
		// Task i needs i+0.5 seconds: tasks 0 and 1 finish before the 2s
		// deadline, while tasks 2-4 time out. The half-second offset keeps
		// every task clearly on one side of the deadline; a task lasting
		// exactly 2s would race it and print nondeterministically.
		work := time.Duration(taskID)*time.Second + 500*time.Millisecond
		pool.Submit(func(ctx context.Context) {
			timer := time.NewTimer(work)
			defer timer.Stop() // release the timer when the task times out
			select {
			case <-timer.C:
				fmt.Printf("Task %d completed\n", taskID)
			case <-ctx.Done():
				fmt.Printf("Task %d timed out\n", taskID)
			}
		})
	}
	pool.Stop()
}
选择合适的 workerpool
选择 workerpool 实现时需要考虑以下因素:
- 简单性:对于基本需求,channel 实现的简单 workerpool 就足够了
- 错误处理:如果需要在任一任务失败时拿到错误并通过 context 取消其余任务,errgroup(配合 SetLimit 限制并发数)是更好的选择
- 高级功能:如果需要优先级、超时控制等高级功能,可以考虑专门的库或自定义实现
- 性能:对于高性能场景,需要评估不同实现的性能特点
希望这些示例能帮助你理解和使用 Go 中的 workerpool 模式!