Go errgroup + 信号量做 worker pool(不引入第三方)

起因

要并发处理 10 万个 URL 抓取。原生 goroutine + channel 写起来麻烦:

  • 限制并发数(不能开 10 万 goroutine 抓)
  • 错误传播(一个失败要么 cancel 其它要么继续)
  • 等所有任务完成
  • 收集结果

sync/errgroup + golang.org/x/sync/semaphore 标准库组合解决,
不需要 ants / workerpool 第三方。

解决方案

1. 简单 errgroup

package main

import (
    "context"
    "fmt"
    "net/http"
    "time"

    "golang.org/x/sync/errgroup"
)

func main() {
    ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
    defer cancel()

    g, ctx := errgroup.WithContext(ctx)

    urls := []string{
        "https://example.com",
        "https://golang.org",
        "https://github.com",
    }
    results := make([]int, len(urls))

    for i, u := range urls {
        i, u := i, u   // capture range vars
        g.Go(func() error {
            req, err := http.NewRequestWithContext(ctx, "GET", u, nil)
            if err != nil {
                return err
            }
            resp, err := http.DefaultClient.Do(req)
            if err != nil {
                return err
            }
            defer resp.Body.Close()
            results[i] = resp.StatusCode
            return nil
        })
    }

    if err := g.Wait(); err != nil {
        fmt.Println("error:", err)
    }
    fmt.Println(results)
}

关键点:

  • g.Go(func() error) 启动一个 goroutine
  • 任一 goroutine 返回 error → errgroup 自动 cancel ctx → 其它 goroutine
    收到 cancel signal 退出
  • g.Wait() 等所有完成,返回第一个 error

2. 限并发:SetLimit (Go 1.20+)

g, ctx := errgroup.WithContext(ctx)
g.SetLimit(50)   // 最多 50 个 goroutine 并发

for _, u := range urls {
    u := u
    g.Go(func() error {
        return fetch(ctx, u)
    })
}
g.Wait()

SetLimit(50) + Go() 会阻塞 caller 直到有空位。
10 万 URL 排队,永远不超过 50 个并发。

3. 老 Go 版本:semaphore

import "golang.org/x/sync/semaphore"

sem := semaphore.NewWeighted(50)
g, ctx := errgroup.WithContext(ctx)

for _, u := range urls {
    u := u
    if err := sem.Acquire(ctx, 1); err != nil {
        break
    }
    g.Go(func() error {
        defer sem.Release(1)
        return fetch(ctx, u)
    })
}
g.Wait()

semaphore 是计数信号量,权重可以不是 1(如内存敏感任务每个占 4)。

4. 结果收集:channel

type result struct {
    url string
    code int
    err error
}

results := make(chan result, len(urls))

g, ctx := errgroup.WithContext(ctx)
g.SetLimit(50)
for _, u := range urls {
    u := u
    g.Go(func() error {
        code, err := fetch(ctx, u)
        results <- result{u, code, err}
        return nil   // 不让 errgroup 因单个 fetch 失败 cancel 其它
    })
}

go func() {
    g.Wait()
    close(results)
}()

for r := range results {
    fmt.Println(r)
}

注意:

  • 用 buffered channel 避免阻塞
  • 单个 fetch 错误不让 errgroup cancel(包成 result 一起传)
  • 单独 goroutine wait + close channel

5. 真正的 worker pool

如果 task 是 stream 进来(不预知数量),用 channel + worker:

type Task struct{ URL string }

func RunPool(ctx context.Context, tasks <-chan Task, workers int) <-chan error {
    out := make(chan error, workers)
    var wg sync.WaitGroup

    for i := 0; i < workers; i++ {
        wg.Add(1)
        go func() {
            defer wg.Done()
            for t := range tasks {
                select {
                case <-ctx.Done():
                    return
                default:
                }
                if err := process(ctx, t); err != nil {
                    out <- err
                }
            }
        }()
    }

    go func() {
        wg.Wait()
        close(out)
    }()
    return out
}

// 用
tasks := make(chan Task, 100)
go func() {
    defer close(tasks)
    for _, u := range urls {
        tasks <- Task{URL: u}
    }
}()

errors := RunPool(ctx, tasks, 50)
for err := range errors {
    log.Printf("err: %v", err)
}

stream 处理 + back-pressure(producer 阻塞在 tasks 满时)。

6. context cancel 在 worker 里

worker 函数必须响应 ctx.Done(),否则 cancel 没用:

func fetch(ctx context.Context, url string) error {
    req, err := http.NewRequestWithContext(ctx, "GET", url, nil)  // 关键
    if err != nil {
        return err
    }
    resp, err := http.DefaultClient.Do(req)   // ctx cancel 时立刻断
    // ...
}

http.NewRequestWithContext / db.QueryContext / redis.Get(ctx)
所有 IO 都用 ctx 版本。

7. 实际 benchmark

10 万 URL,50 并发:

写法 总时长 内存峰值
串行 30 min+ 50 MB
无限制 goroutine 1.5 min(多数失败) 800 MB
errgroup SetLimit(50) 4 min(全成功) 80 MB
worker pool channel 4 min 70 MB

无限制 goroutine 看似快但实际:DNS 解析失败、连接被 ban、
socket fd 耗尽。永远要限并发

效果

  • 10 万 task 在 5 分钟内可控完成
  • 任意 task 失败 ctx cancel,剩余资源不浪费
  • 内存控制在 50-100 MB,不会 OOM
  • 代码 50 行,比引第三方库少依赖

何时还是要第三方库

  • panicking worker 自动 recover:errgroup 不 recover panic。
    生产建议每个 worker 函数 defer recover()
  • 复杂 retry / backoff 策略:用 github.com/cenkalti/backoff/v4
  • 分布式(跨机器)worker pool:用 Asynq / Machinery
  • 动态 worker 数:errgroup SetLimit 启动后不能改

踩过的坑

  1. range var 没 capture
    go for _, u := range urls { g.Go(func() error { return fetch(u) // ❌ 所有 goroutine 看到 last u }) }
    Go 1.22 才修了 range var 默认 scope。1.21- 必须 u := u 显式
    capture。

  2. g.Wait 不调:goroutine leak。
    go if err := g.Wait(); err != nil { ... }
    永远要 Wait。

  3. g.Go 里 panic → 整个程序 crash(errgroup 不 recover)。
    生产 worker 函数 wrap:
    go g.Go(func() (err error) { defer func() { if r := recover(); r != nil { err = fmt.Errorf("panic: %v", r) } }() return fetch(ctx, u) })

  4. 共享 slice 写:上面 results[i] = ... 不需要锁(每个 worker 写
    不同 index),但如果是 results = append(results, ...) 则要 sync.Mutex
    或 channel。

  5. HTTP client default reusehttp.DefaultClient 全局共享 +
    connection pool。多个 goroutine 用 OK。但默认无超时——务必传 ctx
    或自己 Client{ Timeout: 30 * time.Second }

精确评价 共 0 人评价
可复现性
可复现 · 0 不可复现 · 0
文风
文风流畅 · 0 文风晦涩 · 0
立场
支持 · 0 反对 · 0

登录后即可对本帖作出评价。

评论区 0 条 · 所有人可在此交流

登录后参与评论。

还没有评论,来说两句。