I build the pipeline with thread-pool functions and pass a context.Context into it as an argument. When the cancel() function is called, or the timeout expires, the pipeline must terminate gracefully so that no working goroutines are left.
The functions I work with:
func generate(amount int) <-chan int {
    result := make(chan int)
    go func() {
        defer close(result)
        for i := 0; i < amount; i++ {
            result <- i
        }
    }()
    return result
}

func sum(input <-chan int) int {
    result := 0
    for el := range input {
        result += el
    }
    return result
}

func process[T any, R any](ctx context.Context, workers int, input <-chan T, do func(T) R) <-chan R {
    wg := new(sync.WaitGroup)
    result := make(chan R)
    for i := 0; i < workers; i++ {
        wg.Add(1)
        go func() {
            defer wg.Done()
            for {
                select {
                case <-ctx.Done():
                    return
                case val, ok := <-input:
                    if !ok {
                        return
                    }
                    select {
                    case <-ctx.Done():
                        return
                    case result <- do(val):
                    }
                }
            }
        }()
    }
    go func() {
        defer close(result)
        wg.Wait()
    }()
    return result
}
Usage:
func main() {
    ctx, cancel := context.WithTimeout(context.Background(), 1200*time.Millisecond)
    defer cancel()
    input := generate(1000)
    multiplied := process(ctx, 15, input, func(val int) int {
        time.Sleep(time.Second)
        return val * 2
    })
    increased := process(ctx, 15, multiplied, func(val int) int {
        return val + 10
    })
    fmt.Println("Result: ", sum(increased))                // 360 is ok
    fmt.Println("Num goroutine: ", runtime.NumGoroutine()) // 18 is too much
}
I understand that this happens because all the increase goroutines have ended while the multiply goroutines are still running (1 main goroutine + 1 generate + 15 multiply workers + their closer goroutine adds up to the 18 reported).
Is there any canonical way to solve this problem?
You are expecting something like structured concurrency, where all goroutines end by the end of the current scope, but your code is not designed to meet that expectation. You will leak generate when the input channel is not drained, and your do functions are not cancellable.
Adding cancellability to generate and your do functions helps a little:
package main

import (
    "context"
    "fmt"
    "runtime"
    "sync"
    "time"
)

func main() {
    ctx, cancel := context.WithTimeout(context.Background(), 1200*time.Millisecond)
    defer cancel()
    input := generate(ctx, 1_000)
    multiplied := process(ctx, 15, input, func(ctx context.Context, val int) (int, error) {
        select {
        case <-ctx.Done():
            return 0, ctx.Err()
        case <-time.After(time.Second):
            return val * 2, nil
        }
    })
    increased := process(ctx, 15, multiplied, func(_ context.Context, val int) (int, error) {
        return val + 10, nil
    })
    fmt.Println("Result: ", sum(increased)) // 360 is ok
    // This measurement races with workers that are still shutting down, so the
    // printed count varies, but it drops once generate and do observe ctx.
    fmt.Println("Num goroutine: ", runtime.NumGoroutine())
}
func generate(ctx context.Context, amount int) <-chan int {
    input := make(chan int)
    go func() {
        defer close(input)
        for i := 0; i < amount; i++ {
            select {
            case <-ctx.Done():
                return
            case input <- i:
            }
        }
    }()
    return input
}

func sum(input <-chan int) int {
    result := 0
    for el := range input {
        result += el
    }
    return result
}
func process[T any, R any](ctx context.Context, workers int, input <-chan T, do func(context.Context, T) (R, error)) <-chan R {
    wg := new(sync.WaitGroup)
    result := make(chan R)
    for i := 0; i < workers; i++ {
        wg.Add(1)
        go func() {
            defer wg.Done()
            for {
                select {
                case <-ctx.Done():
                    return
                case val, ok := <-input:
                    if !ok {
                        return
                    }
                    r, err := do(ctx, val)
                    if err != nil {
                        return
                    }
                    result <- r
                }
            }
        }()
    }
    go func() {
        defer close(result)
        wg.Wait()
    }()
    return result
}
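One gap remains in this version (part of why it only helps a little): the bare result <- r send blocks forever if the consumer stops reading before result is drained. In this particular program sum always drains the final channel, so it works out, but a reusable process is safer if that send is also guarded by the context. A sketch of that variant, with the same signature as above:

// process, with the send of r also guarded by ctx, so a worker never blocks
// forever on a consumer that has stopped reading.
func process[T any, R any](ctx context.Context, workers int, input <-chan T, do func(context.Context, T) (R, error)) <-chan R {
    wg := new(sync.WaitGroup)
    result := make(chan R)
    for i := 0; i < workers; i++ {
        wg.Add(1)
        go func() {
            defer wg.Done()
            for {
                select {
                case <-ctx.Done():
                    return
                case val, ok := <-input:
                    if !ok {
                        return
                    }
                    r, err := do(ctx, val)
                    if err != nil {
                        return
                    }
                    select {
                    case <-ctx.Done():
                        return // give up the value instead of blocking
                    case result <- r:
                    }
                }
            }
        }()
    }
    go func() {
        defer close(result)
        wg.Wait()
    }()
    return result
}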
More is covered in “Advanced Go Concurrency Patterns”, but as a general recommendation: when you aim for structured concurrency, write synchronous code first and only later work on running it concurrently.
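For example, the whole pipeline above is, written synchronously, just one loop; getting that right first makes it much easier to see what the concurrent version must preserve. A minimal illustrative sketch (the name pipelineSync is mine, not from the code above):

// pipelineSync is the synchronous equivalent of generate -> multiply -> increase -> sum.
// No goroutines, no channels, nothing to leak; workers and cancellation can be layered on later.
func pipelineSync(amount int) int {
    total := 0
    for i := 0; i < amount; i++ {
        multiplied := i * 2
        increased := multiplied + 10
        total += increased
    }
    return total
}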