Sergey Kamardin

Originally published at sergey.kamardin.org

Generic Concurrency in Go

Hello Gophers!

In this article, I want to share some thoughts and ideas I've accumulated over time about generics in Go, and in particular about concurrency patterns, which can now become more reusable and convenient thanks to generics.

TL;DR

Generics and goroutines (and iterators in the future) are great tools we can leverage to build reusable, general-purpose concurrent processing into our programs.

In this article we explore the possibilities of combining them.

Introduction

Let's quickly scratch the surface with some basic context and small examples to see what problem generics solve and how we can fuse them with the existing concurrency model.

In this article we are going to think a lot about mapping collections (sets, sequences) of elements. Mapping is a process that produces a new collection in which each element is the result of calling some function f() with the corresponding element of the initial collection.

Pre-Generics era

Let's define a first, simple mapping of integers (in the Go snippets we will call it transform() to avoid confusion with the builtin map type):

func transform([]int, func(int) int) []int

Sample implementation
func transform(xs []int, f func(int) int) []int {
    ret := make([]int, len(xs))
    for i, x := range xs {
        ret[i] = f(x)
    }
    return ret
}

An example use of such function would look like this:

// Output: [1, 4, 9]
transform([]int{1, 2, 3}, func(n int) int {
    return n * n
})

Now let's assume we want to map integers to strings. That's easy: we can define transform() just slightly differently:

func transform([]int, func(int) string) []string

So we can use it this way:

// Output: ["1", "2", "3"]
transform([]int{1, 2, 3}, strconv.Itoa) 

What about reporting whether a number is odd or even? Just another tiny correction:

func transform([]int, func(int) bool) []bool

So we could use it this way:

// Output: [false, true, false]
transform([]int{1, 2, 3}, func(n int) bool {
    return n % 2 == 0
})

Generalising the changes we made to transform() for each use case above, we can say that regardless of the types it operates on, it does exactly the same thing over and over again. If we were to generate the code for each combination of types, we could use a template like this:

func transform_{{ .A }}_{{ .B }}([]{{ .A }}, func({{ .A }}) {{ .B }}) []{{ .B }}

// transform_int_int([]int, func(int) int) []int
// transform_int_string([]int, func(int) string) []string
// transform_int_bool([]int, func(int) bool) []bool

In fact, for pre-generics versions of Go there were a few nice code generation tools that did pretty much this kind of templating; genny is just one example.

Generics era

Thanks to generics, we now have the ability to parametrise functions and types with type parameters and define transform() this way:

func transform[A, B any]([]A, func(A) B) []B

And the implementation changes just a little bit!
func transform[A, B any](xs []A, f func(A) B) []B {
    ret := make([]B, len(xs))
    for i, x := range xs {
        ret[i] = f(x)
    }
    return ret
}

So we can now use it for any input and output types (assuming func square(int) int and func isEven(int) bool are defined somewhere in the package, matching the closures used above):

transform([]int{1, 2, 3}, square)       // [1, 4, 9]
transform([]int{1, 2, 3}, strconv.Itoa) // ["1", "2", "3"]
transform([]int{1, 2, 3}, isEven)       // [false, true, false]

Concurrent mapping

Okay, now let's get on to the main subject of this article and focus on concurrency patterns that can benefit from generics.

The x/sync/errgroup package

Before jumping into lots of code snippets, let's take a small step aside and look at the (very popular) golang.org/x/sync/errgroup package. In short, it allows you to start a number of goroutines to perform different tasks and wait for their completion or failure.

It is supposed to be used this way:

// Create workers group and a context which will get canceled if any of the
// tasks fails.
g, gctx := errgroup.WithContext(ctx)
g.Go(func() error {
    return doSomeFun(gctx)
})
g.Go(func() error {
    return doEvenMoreFun(gctx)
})
if err := g.Wait(); err != nil {
    // handle error
}

I mention the package because, viewed from a slightly different and more general perspective, it is essentially the same mapping thing. The package allows you to concurrently map a set of tasks to a corresponding set of results, and it provides a general way to handle and propagate errors, as well as to cancel subtasks (via context cancellation) if any of them fails.

In this article we want to build something similar and, as the repeated use of the word "generic" suggests, we will do it in a generic way.

Naive implementation

Let's get back to the transform() function. Assuming that all the calls to f() can be made concurrently without breaking our (or anyone else's) program, we can start with this naive concurrent implementation:

func transform[A, B any](as []A, f func(A) B) []B {
    bs := make([]B, len(as))

    var wg sync.WaitGroup
    for i := 0; i < len(as); i++ {
        wg.Add(1)
        go func(i int) {
            defer wg.Done()
            bs[i] = f(as[i])
        }(i)
    }
    wg.Wait()

    return bs
}

That is, we start a goroutine for each element of the input and call f(elem) in it, storing the result at the corresponding index of the shared slice bs. No context, no cancellation, not even errors: this one doesn't look very helpful for anything besides pure computation.
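A small driver, assuming nothing beyond the naive transform() above, shows why writing to bs[i] needs no extra locking: every goroutine owns a distinct index, so results stay aligned with their inputs even when goroutines finish out of order. The durations are arbitrary values chosen for the demonstration.

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// transform is the naive concurrent version: one goroutine per element, each
// writing its result to its own index of the shared slice.
func transform[A, B any](as []A, f func(A) B) []B {
	bs := make([]B, len(as))
	var wg sync.WaitGroup
	for i := 0; i < len(as); i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			bs[i] = f(as[i])
		}(i)
	}
	wg.Wait()
	return bs
}

func main() {
	// Earlier elements sleep longer, so their goroutines finish last; the
	// output order still matches the input order.
	delays := []time.Duration{30 * time.Millisecond, 20 * time.Millisecond, 10 * time.Millisecond}
	out := transform(delays, func(d time.Duration) string {
		time.Sleep(d)
		return d.String()
	})
	fmt.Println(out) // [30ms 20ms 10ms]
}
```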

Context cancellation

In the real world, many or even most concurrent tasks, especially I/O-related ones, are controlled by a context.Context instance. Where there is a context, there can be a timeout or a cancellation. Let's take that into account:

func transform[A, B any](
    ctx context.Context,
    as []A,
    f func(context.Context, A) (B, error),
) (
    []B,
    error,
) {
    bs := make([]B, len(as))
    es := make([]error, len(as))

    subctx, cancel := context.WithCancel(ctx)
    defer cancel()

    var wg sync.WaitGroup
    for i := 0; i < len(as); i++ {
        wg.Add(1)
        go func(i int) {
            defer wg.Done()
            bs[i], es[i] = f(subctx, as[i])
            if es[i] != nil {
                cancel()
            }
        }(i)
    }
    wg.Wait()

    err := errors.Join(es...)
    if err != nil {
        return nil, err
    }
    return bs, nil
}

Now we have one more shared slice, es, to store errors potentially returned by f(). If any goroutine's f() call fails, we cancel the sub-context shared by all the calls and expect every in-flight f() call to respect the cancellation and return as soon as possible.

Limiting concurrency

In reality, we cannot implicitly assume too much about f(). Users of transform() might want to limit the number of concurrent calls to f(). For example, f() can map a URL to the result of an HTTP request. Without any limits we can overwhelm the server or get banned ourselves.

Let's not worry about the structure of the parameters for now, and just add a parallelism int argument to the function.

At this point we need to switch from sync.WaitGroup to a semaphore channel, since we want to use select both to control the (maximum) number of simultaneously running goroutines and to handle context cancellation.

func transform[A, B any](
    ctx context.Context,
    parallelism int,
    as []A,
    f func(context.Context, A) (B, error),
) (
    []B,
    error,
) {
    bs := make([]B, len(as))
    es := make([]error, len(as))

    // FIXME: if the given context is already cancelled, no worker will be
    // started but the transform() call will return bs, nil.
    subctx, cancel := context.WithCancel(ctx)
    defer cancel()

    sem := make(chan struct{}, parallelism)
sched:
    for i := 0; i < len(as); i++ {
        // We are checking the sub-context cancellation here, in addition to
        // the user-provided context, to handle cases where f() returns an
        // error, which leads to the termination of transform.
        if subctx.Err() != nil {
            break
        }
        select {
        case <-subctx.Done():
            break sched

        case sem <- struct{}{}:
            // Being able to send a tick into the channel means we can start a
            // new worker goroutine. This could be either due to the completion
            // of a previous goroutine or because the number of started worker
            // goroutines is less than the given parallelism value.
        }
        go func(i int) {
            defer func() {
                // Signal that the element has been processed and the worker
                // goroutine has completed.
                <-sem
            }()
            bs[i], es[i] = f(subctx, as[i])
            if es[i] != nil {
                cancel()
            }
        }(i)
    }
    // Since each goroutine reads off one tick from the semaphore before exit,
    // filling the channel with artificial ticks makes us sure that all started
    // goroutines completed their execution.
    //
    // FIXME: for the high values of parallelism this loop becomes slow.
    for i := 0; i < cap(sem); i++ {
        // NOTE: we do not check the user-provided context here because we want
        // to return from this function only when all the started worker
        // goroutines have completed. This is to avoid surprising users with
        // some of the f() function calls still running in the background after
        // transform() returns.
        //
        // This implies f() should respect context cancellation and return as
        // soon as its context gets cancelled.
        sem <- struct{}{}
    }

    err := errors.Join(es...)
    if err != nil {
        return nil, err
    }
    return bs, nil
}

For this and the next iterations of transform(), we could actually keep the implementation as it was and leave both use cases at the mercy of the f() implementation. For example, we could just start N goroutines regardless of the concurrency limits and let the user of transform() partially serialise them the way they want to. That would require the overhead of starting N goroutines instead of P (where P is the "parallelism" limit, which can be much less than N). It would also imply some compute overhead for synchronising the goroutines, depending on the mechanism used. Since all of this is unnecessary, we proceed with the implementation the hard way, but in many cases these complications are optional.

Example user based implementation
// Initialised x/time/rate.Limiter instance.
var lim *rate.Limiter
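transform(ctx, as, func(ctx context.Context, url string) (int, error) {
    // Wait on the context passed in by transform(), so that waiting stops as
    // soon as any other call fails and cancels it.
    if err := lim.Wait(ctx); err != nil {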
        return 0, err
    }

    // Process url.

    return 42, nil
})

Reusing goroutines

In the previous iteration we were starting a goroutine per task, but no more than parallelism goroutines at a time. This highlights another interesting option: users might want to have a custom execution context for each goroutine. For example, suppose we have N tasks with at most P running concurrently (and P can be significantly less than N). If each task requires some form of resource preparation, such as a large memory allocation, a database session, or maybe a single-threaded Cgo "coroutine", it would seem logical to prepare only P resources and reuse them among workers through the context.

Again, let's keep the structure of passing options aside.

func transform[A, B any](
    ctx context.Context,
    prepare func(context.Context) (context.Context, context.CancelFunc),
    parallelism int,
    as []A,
    f func(context.Context, A) (B, error),
) (
    []B,
    error,
) {
    bs := make([]B, len(as))
    es := make([]error, len(as))

    // FIXME: if the given context is already cancelled, no worker will be
    // started but the transform() call will return bs, nil.
    subctx, cancel := context.WithCancel(ctx)
    defer cancel()

    sem := make(chan struct{}, parallelism)
    wrk := make(chan int)
sched:
    for i := 0; i < len(as); i++ {
        // We are checking the sub-context cancellation here, in addition to
        // the user-provided context, to handle cases where f() returns an
        // error, which leads to the termination of transform.
        if subctx.Err() != nil {
            break
        }
        select {
        case <-subctx.Done():
            break sched

        case wrk <- i:
            // There is an idle worker goroutine that is ready to process the
            // next element.
            continue

        case sem <- struct{}{}:
            // Being able to send a tick into the channel means we can start a
            // new worker goroutine. This could be either due to the completion
            // of a previous goroutine or because the number of started worker
            // goroutines is less than the given parallelism value.
        }
        go func(i int) {
            defer func() {
                // Signal that the element has been processed and the worker
                // goroutine has completed.
                <-sem
            }()

            // Capture the subctx from the dispatch loop. This prevents
            // overriding it if the given prepare() function is not nil.
            subctx := subctx
            if prepare != nil {
                var cancel context.CancelFunc
                subctx, cancel = prepare(subctx)
                defer cancel()
            }
            for {
                bs[i], es[i] = f(subctx, as[i])
                if es[i] != nil {
                    cancel()
                    return
                }
                var ok bool
                i, ok = <-wrk
                if !ok {
                    // Work channel has been closed, which means we will not
                    // get any new tasks for this worker and can return.
                    break
                }
            }
        }(i)
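    }
    // Signal the worker goroutines that there are no more elements to process,
    // so that idle workers blocked on <-wrk can return and release sem.
    close(wrk)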
    // Since each goroutine reads off one tick from the semaphore before exit,
    // filling the channel with artificial ticks makes us sure that all started
    // goroutines completed their execution.
    //
    // FIXME: for the high values of parallelism this loop becomes slow.
    for i := 0; i < cap(sem); i++ {
        // NOTE: we do not check the user-provided context here because we want
        // to return from this function only when all the started worker
        // goroutines have completed. This is to avoid surprising users with
        // some of the f() function calls still running in the background after
        // transform() returns.
        //
        // This implies f() should respect context cancellation and return as
        // soon as its context gets cancelled.
        sem <- struct{}{}
    }

    err := errors.Join(es...)
    if err != nil {
        return nil, err
    }
    return bs, nil
}

At this point we start up to P goroutines and distribute tasks across them using the unbuffered channel wrk. The channel is unbuffered because we want immediate runtime "feedback" on whether there are any idle workers at the moment or whether we should consider starting a new one. Once all the tasks are processed or any of the f() calls fails, we signal all the started goroutines to return by closing wrk.

As in the previous section, this might be done inside f() too, for example by using sync.Pool. f() could acquire a resource (or create one when there are no idle resources) and release it once it's no longer needed. Since the set of goroutines is fixed, odds are the resources will have nice CPU locality, so the overhead could be minimal.

Example user based implementation
// Note that this snippet assumes `transform()` can limit its concurrency.
var pool sync.Pool
transform(ctx, 8, as, func(_ context.Context, userID string) (int, error) {
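    // Get() returns nil when the pool is empty and has no New function, so
    // use the two-value type assertion: a plain one would panic on nil.
    sess, _ := pool.Get().(*db.Session)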
    if sess == nil {
        // Initialise database session.
    }
    defer pool.Put(sess)

    // Process userID.

    return 42, nil
})

Generalisation of transform()

So far our focus has been on mapping slices, which in many cases is enough. However, what if we want to map map types, or even channels? Can we map anything that we can range over? And, as with for loops, do we really always need to map values?

All these are interesting questions that lead us to the idea of generalising our concurrent iteration approach. We can have a "low level" function that behaves almost the same but makes fewer assumptions about its input and output. Then it will take just a little effort to build a more specific transform() on top of it. Let's call the function iterate() and represent its input and output as functions instead of data types. We will pull() elements from the input and push() the results back to the user. This way the user of iterate() controls how the input elements are provided and how the results are handled.
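To make the function-based input concrete, here is how a few kinds of collections could be adapted to the pull() shape. pullSlice(), pullChan(), pullMap() and drain() are hypothetical helpers, not part of the article's code.

```go
package main

import "fmt"

// pullSlice yields the elements of xs in order.
func pullSlice[A any](xs []A) func() (A, bool) {
	var i int
	return func() (a A, ok bool) {
		if i == len(xs) {
			return a, false
		}
		i++
		return xs[i-1], true
	}
}

// pullChan yields values received from ch until ch is closed.
func pullChan[A any](ch <-chan A) func() (A, bool) {
	return func() (A, bool) {
		a, ok := <-ch
		return a, ok
	}
}

// entry pairs a map key with its value.
type entry[K comparable, V any] struct {
	Key K
	Val V
}

// pullMap yields the entries of m. Maps are unordered, so the order of the
// yielded entries is unspecified.
func pullMap[K comparable, V any](m map[K]V) func() (entry[K, V], bool) {
	es := make([]entry[K, V], 0, len(m))
	for k, v := range m {
		es = append(es, entry[K, V]{k, v})
	}
	return pullSlice(es)
}

// drain calls pull until it reports no more elements; it plays the role of a
// sequential loop with a push() that simply appends.
func drain[A any](pull func() (A, bool)) []A {
	var out []A
	for a, ok := pull(); ok; a, ok = pull() {
		out = append(out, a)
	}
	return out
}

func main() {
	ch := make(chan string, 2)
	ch <- "a"
	ch <- "b"
	close(ch)

	fmt.Println(drain(pullSlice([]int{1, 2, 3})))                    // [1 2 3]
	fmt.Println(drain(pullChan[string](ch)))                          // [a b]
	fmt.Println(len(drain(pullMap(map[string]int{"x": 1, "y": 2})))) // 2
}
```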

We also have to consider what results iterate() should push to the user. Since we plan to make the mapping of input elements optional, (B, error) no longer seems like the only right and obvious option. This part is actually quite subtle, and the majority of use cases might well benefit from keeping it as it was and returning the error explicitly. Semantically, however, it doesn't make much sense: the f() result is merely proxied to the push() call without any processing, which means iterate() makes no presumptions about the result. In other words, the result only means something to the push() implementation, which is provided by the user. Additionally, this signature works better with Go iterators, which we'll cover at the end of this article. With this in mind, let's reduce the number of return parameters to one. Since we intend to push results through a function call, we should do it in a serialised way. transform() and, later, iterate() already have all the needed synchronisation internally, so the user can collect the results without any extra synchronisation effort on their side.

Another thing to cover is the way we handled errors earlier: we did not map an error to the input element that caused it. It is true that f() could wrap the error, but it is cleaner for f() to remain unaware of the way it will be called. In other words, f() should not assume it's being called as the iterate() argument. If it is invoked with a single element a, there is no point in wrapping a into the error, as it's obvious to the caller that this particular a caused the error. This principle leads us to another observation: any binding of an input element to an error (or any other result) should also happen during the push() execution. For the same reasons, it is the push() function that should control the iteration and decide whether a faulty result should interrupt the loop.

Additionally, this iterate() design naturally provides flow control. If the user does something slow in the push() function, the other worker goroutines will eventually pause processing new elements. This is because they block on sending their f() call results into the res channel, which is drained by the same goroutine that calls push().
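The back-pressure effect can be observed in isolation. runPipeline() below is a hypothetical demo, not part of iterate(): workers send results over an unbuffered channel to a deliberately slow consumer, so each worker can be at most one finished result ahead of it.

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
	"time"
)

// runPipeline starts workers goroutines that square numbers and send them over
// an unbuffered channel to a slow consumer (standing in for push()). It
// returns how many results were consumed and the largest observed gap between
// results produced and results consumed.
func runPipeline(workers, tasks int) (consumed, maxAhead int) {
	in := make(chan int)
	res := make(chan int) // Unbuffered, like res in iterate().
	var produced atomic.Int32

	var wg sync.WaitGroup
	for w := 0; w < workers; w++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for n := range in {
				r := n * n // The "f() call" finishes here...
				produced.Add(1)
				res <- r // ...but delivery blocks until the consumer is ready.
			}
		}()
	}
	go func() {
		for i := 0; i < tasks; i++ {
			in <- i
		}
		close(in)
		wg.Wait()
		close(res)
	}()

	for range res {
		consumed++
		if ahead := int(produced.Load()) - consumed; ahead > maxAhead {
			maxAhead = ahead
		}
		time.Sleep(time.Millisecond) // A slow push().
	}
	return consumed, maxAhead
}

func main() {
	consumed, ahead := runPipeline(3, 20)
	fmt.Println(consumed, ahead <= 3) // 20 true
}
```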

As the full function listing is quite long, let's cover it part by part.

Signature

func iterate[A, B any](
    ctx context.Context,
    prepare func(context.Context) (context.Context, context.CancelFunc),
    parallelism int,
    pull func() (A, bool),
    f func(context.Context, A) B,
    push func(A, B) bool,
) (err error) {

Neither the arguments nor the return parameters include input and output slices anymore, as they did for the transform() function. Instead, input elements are pulled in by calling the pull() function, and results are pushed back to the user by calling the push() function. Note that push() returns a bool that controls the iteration: once it returns false, no more push() calls will be made and all ongoing f() executions will get their context cancelled. iterate() returns just an error, which can only be non-nil when the iteration is terminated due to cancellation of the given ctx; otherwise there is no way of knowing why the iteration has stopped.

There are just three cases in which the loop can be terminated:

  • pull() returned false, meaning there are no more elements to process.
  • push() returned false, meaning the user doesn't need any further results.
  • the parent ctx got cancelled.

Even so, without extra bookkeeping in user code it's hard to tell whether all of the elements were processed before the parent context got cancelled.

Example
Let's assume we want to implement a concurrent forEach() using iterate():
func forEach[A any](
    ctx context.Context,
    in []A,
    f func(context.Context, A) error,
) (err error) {
    var i int
    iterate(ctx, nil, 0,
        func() (_ A, ok bool) {
            if i == len(in) {
                return
            }
            i++
            return in[i-1], true
        },
        f,
        func(_ A, e error) bool {
            err = e
            return e == nil
        },
    )
    if err == nil {
        // BUG: if we returned from `iterate()` call we either processed all
        // input _or_ ctx got cancelled and the iteration got interrupted.
        //
        // Simply checking if ctx.Err() is non-nil here is racy, and may
        // provide false faulty result in case when we processed all the input
        // and _then_ context got cancelled.
        //
        // On the other hand, checking here if i == len(in) as a condition of
        // completeness is incorrect, as we might pull the last element to
        // process and _then_ got interrupted by the context cancelation.
        //
        // So if iterate() doesn't return an error, one should track each
        // element processing state in `f()` call wrapper to correctly
        // distinguish cases above.
        err = ctx.Err()
    }
    return
}

Prologue

    // Create sub-context for the dispatch loop goroutine so we can stop it
    // once the user wants to stop the iteration.
    subctx, cancel := context.WithCancel(ctx)
    defer cancel()

    // result represents input element A and the result B caused by applying
    // the given function f() to A.
    type result struct {
        a A
        b B
    }
    // loopInfo contains the dispatch loop state.
    //
    // The dispatch goroutine below signals current goroutine about the loop
    // termination by sending loopInfo to the term channel below. The current
    // goroutine uses it to understand how many elements have been dispatched
    // for processing to decide for how many results to await.
    type loopInfo struct {
        dispatched int
        err        error
    }
    // These channels are receive-only for the current goroutine and send-only
    // in the dispatch goroutine. For the sake of readability, no channel
    // direction constraints are added.
    var (
        res  = make(chan result)
        term = make(chan loopInfo, 1)
    )

    // This wait group is used to track completion of worker goroutines started
    // by the dispatch goroutine.
    var wg sync.WaitGroup

In the previous versions of transform() we stored results at the index of the input element in the results slice, as in bs[i] = f(as[i]). This is no longer possible with function-based input and output. So, as soon as we have a result, we likely need to push() it to the user right away. This is why we want two goroutines, one for dispatching the input elements and one for pushing the results back to the user: while we are dispatching an input, we might already get an output.

Dispatch goroutine

    // Start the dispatch goroutine. Its purpose is to control the worker
    // goroutines, dispatch input elements among the workers, and eventually
    // signal the current goroutine about the dispatch loop termination.
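    // Track the dispatch goroutine in the wait group too: this way every
    // wg.Add() it performs happens while the counter is positive (as
    // sync.WaitGroup requires when Wait may run concurrently), and the final
    // wg.Wait() also waits for the dispatch loop, including any pull() call
    // in progress.
    wg.Add(1)
    go func() {
        defer wg.Done()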
        // wrk is a channel of input elements. It is send-only for the dispatch
        // goroutine and receive-only for the worker goroutines.
        wrk := make(chan A)

        var loop loopInfo
        defer func() {
            // Signal the workers there are no more elements to dispatch.
            close(wrk)
            // Report the dispatch loop state to the parent goroutine.
            term <- loop
        }()

        var workersCount int
        // We use a _closed_ channel here to make the select below to always be
        // able to receive from it and start up to the given parallelism number
        // of goroutines. Once workersCount == parallelism, we set the variable
        // to nil so that the select cannot read from it after.
        //
        // This is needed to:
        // - Support the special case when parallelism is 0, so that there are
        //   no limits on the number of workers.
        // - Avoid wasting time "corking" the semaphore channel while waiting
        //   for all started goroutines to complete, especially if given a
        //   large parallelism value.
        sem := make(chan struct{})
        close(sem)
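The sem trick relies on two properties of receive operations in select: a receive from a closed channel always proceeds immediately (yielding the zero value), while a receive from a nil channel never proceeds, so select never picks that case. ready() is a hypothetical helper that makes this observable.

```go
package main

import "fmt"

// ready reports whether a receive from ch can proceed right now.
func ready(ch chan struct{}) bool {
	select {
	case <-ch:
		return true
	default:
		return false
	}
}

func main() {
	sem := make(chan struct{})
	close(sem)
	// A closed channel never runs out of "ticks".
	fmt.Println(ready(sem), ready(sem)) // true true

	// A nil channel case is effectively disabled in select.
	sem = nil
	fmt.Println(ready(sem)) // false
}
```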

Dispatch loop

        for {
            if err := subctx.Err(); err != nil {
                loop.err = err
                return
            }
            a, ok := pull()
            if !ok {
                // No more input elements.
                return
            }
            if parallelism != 0 && workersCount == parallelism {
                // Prevent starting more workers.
                sem = nil
            }
            select {
            case <-subctx.Done():
                loop.err = ctx.Err()
                return

            case wrk <- a:
                // There is an idle worker goroutine that is ready to process
                // the next element.
                loop.dispatched++
                continue

            case <-sem:
                // Being able to _receive_ a tick from the channel means we can
                // start a new worker goroutine.
                loop.dispatched++
            }

            workersCount++
            wg.Add(1)

Worker goroutine

            go func(a A) {
                defer wg.Done()

                // Capture the subctx from the topmost scope. This prevents
                // overriding it if the given prepare() function is not nil.
                subctx := subctx
                if prepare != nil {
                    var cancel context.CancelFunc
                    subctx, cancel = prepare(subctx)
                    defer cancel()
                }
                for {
                    r := result{a: a}
                    r.b = f(subctx, a)
                    select {
                    case res <- r:
                    case <-subctx.Done():
                        // If the context is cancelled, it means no more
                        // results are expected.
                        return
                    }
                    var ok bool
                    a, ok = <-wrk
                    if !ok {
                        break
                    }
                }
            }(a)

Results collection

        }
    }()

collect:
    // Wait for the results sent by the worker goroutines.
    //
    // Note the initial -1 value for the num variable since the number of
    // elements pulled and dispatched is unknown yet. We will be notified by
    // the dispatch goroutine once the input ends or the iteration is
    // terminated.
    for i, num := 0, -1; num == -1 || i < num; {
        select {
        case <-ctx.Done():
            // We need to explicitly handle _parent_ context cancellation here
            // because it's an external interruption for us. We ignore the
            // dispatch loop termination event and stop to receive and push
            // results unconditionally.
            if err == nil {
                err = ctx.Err()
            }
            break collect

        case res := <-res:
            if !push(res.a, res.b) {
                // The user wants to stop the iteration. Signal the dispatch
                // loop about this. Note that in this case, we ignore the term
                // channel message and not return any error.
                cancel()
                break collect
            }
            i++

        case loop := <-term:
            // Dispatch loop has now terminated, and we now know the maximum
            // number of results we need receive in this loop.
            num = loop.dispatched
            err = loop.err
        }
    }

    // NOTE: we unconditionally wait for all goroutines to complete in order to
    // return to a clean state. To avoid uninterruptable sleep here users are
    // required to respect context cancellation in the provided f().
    wg.Wait()

    return err
}

As you can see, results are pushed back to the user in no particular order, rather than the order in which they were pulled. This is expected, since we process them concurrently.

Here's an opinionated thought: this combination of sync.WaitGroup and sem channel is a rare example of a justified co-existence of both synchronisation mechanisms in the same code. I believe that in most cases where a channel exists, the wait group is redundant, and vice versa.

And phew, that's it! It was not easy, but it is what we wanted. Let's see how we can use it in the next sections.

Complete code listing
func iterate[A, B any](
    ctx context.Context,
    prepare func(context.Context) (context.Context, context.CancelFunc),
    parallelism int,
    pull func() (A, bool),
    f func(context.Context, A) B,
    push func(A, B) bool,
) (err error) {
    // Create sub-context for the dispatch loop goroutine so we can stop it
    // once the user wants to stop the iteration.
    subctx, cancel := context.WithCancel(ctx)
    defer cancel()

    // result represents input element A and the result B caused by applying
    // the given function f() to A.
    type result struct {
        a A
        b B
    }
    // loopInfo contains the dispatch loop state.
    //
    // The dispatch goroutine below signals current goroutine about the loop
    // termination by sending loopInfo to the term channel below. The current
    // goroutine uses it to understand how many elements have been dispatched
    // for processing to decide for how many results to await.
    type loopInfo struct {
        dispatched int
        err        error
    }
    // These channels are receive-only for the current goroutine and send-only
    // in the dispatch goroutine. For the sake of readability, no channel
    // direction constraints are added.
    var (
        res  = make(chan result)
        term = make(chan loopInfo, 1)
    )

    // This wait group is used to track completion of worker goroutines started
    // by the dispatch goroutine.
    var wg sync.WaitGroup

    // Start the dispatch goroutine. Its purpose is to control the worker
    // goroutines, dispatch input elements among the workers, and eventually
    // signal the current goroutine about the dispatch loop termination.
    go func() {
        // wrk is a channel of input elements. It is send-only for the dispatch
        // goroutine and receive-only for the worker goroutines.
        wrk := make(chan A)

        var loop loopInfo
        defer func() {
            // Signal the workers there are no more elements to dispatch.
            close(wrk)
            // Report the dispatch loop state to the parent goroutine.
            term <- loop
        }()

        var workersCount int
        // We use a _closed_ channel here so that the select below is always
        // able to receive from it and start up to the given parallelism number
        // of goroutines. Once workersCount == parallelism, we set the variable
        // to nil so that the select cannot read from it after that.
        //
        // This is needed to:
        // - Support the special case when parallelism is 0, so that there are
        //   no limits on the number of workers.
        // - Avoid wasting time "corking" the semaphore channel while waiting
        //   for all started goroutines to complete, especially if given a
        //   large parallelism value.
        sem := make(chan struct{})
        close(sem)

        for {
            if err := subctx.Err(); err != nil {
                loop.err = err
                return
            }
            a, ok := pull()
            if !ok {
                // No more input elements.
                return
            }
            if parallelism != 0 && workersCount == parallelism {
                // Prevent starting more workers.
                sem = nil
            }
            select {
            case <-subctx.Done():
                loop.err = ctx.Err()
                return

            case wrk <- a:
                // There is an idle worker goroutine that is ready to process
                // the next element.
                loop.dispatched++
                continue

            case <-sem:
                // Being able to _receive_ a tick from the channel means we can
                // start a new worker goroutine.
                loop.dispatched++
            }

            workersCount++
            wg.Add(1)

            go func(a A) {
                defer wg.Done()

                // Capture the subctx from the topmost scope. This prevents
                // overriding it if the given prepare() function is not nil.
                subctx := subctx
                if prepare != nil {
                    var cancel context.CancelFunc
                    subctx, cancel = prepare(subctx)
                    defer cancel()
                }
                for {
                    r := result{a: a}
                    r.b = f(subctx, a)
                    select {
                    case res <- r:
                    case <-subctx.Done():
                        // If the context is cancelled, it means no more
                        // results are expected.
                        return
                    }
                    var ok bool
                    a, ok = <-wrk
                    if !ok {
                        break
                    }
                }
            }(a)
        }
    }()

collect:
    // Wait for the results sent by the worker goroutines.
    //
    // Note the initial -1 value for the num variable, since the number of
    // elements pulled and dispatched is not known yet. We will be notified by
    // the dispatch goroutine once the input ends or the iteration is
    // terminated.
    for i, num := 0, -1; num == -1 || i < num; {
        select {
        case <-ctx.Done():
            // We need to explicitly handle _parent_ context cancellation here
            // because it's an external interruption for us. We ignore the
            // dispatch loop termination event and unconditionally stop
            // receiving and pushing results.
            if err == nil {
                err = ctx.Err()
            }
            break collect

        case res := <-res:
            if !push(res.a, res.b) {
                // The user wants to stop the iteration. Signal the dispatch
                // loop about this. Note that in this case, we ignore the term
                // channel message and do not return any error.
                cancel()
                break collect
            }
            i++

        case loop := <-term:
            // The dispatch loop has terminated, and we now know the maximum
            // number of results we need to receive in this loop.
            num = loop.dispatched
            err = loop.err
        }
    }

    // NOTE: we unconditionally wait for all goroutines to complete in order to
    // return to a clean state. To avoid uninterruptible sleep here users are
    // required to respect context cancellation in the provided f().
    wg.Wait()

    return err
}
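Two details of the collect loop are worth seeing in isolation. First, the number of expected results is unknown until the dispatch goroutine reports it, which is why num starts at -1. Second, workers pair every send with a receive from the context's Done channel, so an early stop does not leave them blocked. The sketch below is a reduced, hypothetical model (`collectSquares`) with one worker per input, no parallelism limit, and no prepare():

```go
package main

import (
	"context"
	"fmt"
)

// collectSquares is a reduced, hypothetical model of the collect loop in
// iterate(). A dispatcher goroutine starts one worker per input and reports on
// term how many it dispatched. The collector does not know that number in
// advance (num starts at -1). If limit > 0 it stops after limit results, and
// the deferred cancel releases workers still blocked on sending.
func collectSquares(input []int, limit int) []int {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	res := make(chan int)
	term := make(chan int, 1) // Buffered: the dispatcher never blocks on it.

	go func() {
		dispatched := 0
		defer func() { term <- dispatched }()
		for _, x := range input {
			if ctx.Err() != nil {
				return // The collector stopped; dispatch nothing more.
			}
			dispatched++
			go func(x int) {
				select {
				case res <- x * x:
				case <-ctx.Done(): // Nobody will receive anymore.
				}
			}(x)
		}
	}()

	var out []int
	for i, num := 0, -1; num == -1 || i < num; {
		select {
		case r := <-res:
			out = append(out, r)
			i++
			if limit > 0 && i == limit {
				return out
			}
		case n := <-term:
			num = n
		}
	}
	return out
}

func main() {
	fmt.Println(len(collectSquares([]int{1, 2, 3}, 0))) // 3
	fmt.Println(len(collectSquares([]int{1, 2, 3}, 2))) // 2
}
```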

Using iterate() to transform()

To test how the generic iteration function can solve the mapping problem, let's reimplement transform() using it. It is now much shorter, since the concurrent iteration complexity has moved into iterate() and transform() can focus on just storing the mapping results.

func transform[A, B any](
    ctx context.Context,
    prepare func(context.Context) (context.Context, context.CancelFunc),
    parallelism int,
    as []A,
    f func(context.Context, A) (B, error),
) (
    []B, error,
) {
    bs := make([]B, len(as))
    var (
        i    int
        err1 error
    )
    err0 := iterate(ctx, prepare, parallelism,
        func() (int, bool) {
            i++
            return i - 1, i <= len(as)
        },
        func(ctx context.Context, i int) (err error) {
            bs[i], err = f(ctx, as[i])
            return
        },
        func(i int, err error) bool {
            err1 = err
            return err == nil
        },
    )
    if err := errors.Join(err0, err1); err != nil {
        return nil, err
    }
    return bs, nil
}

Reimplemented errgroup

To conclude the analogy with the errgroup package, let's try to implement something similar using the iterate() approach.

type taskFunc func(context.Context) error
func errgroup(ctx context.Context) (
    g func(taskFunc),
    wait func() error,
) {
    task := make(chan taskFunc)
    done := make(chan struct{})

    var (
        err     error
        failure error
    )
    go func() {
        defer close(done)

        // NOTE: we ignore the context preparation here as we don't need it. We
        // also don't limit amount of goroutines running at the same time -- we
        // want each task to start to be executed as soon as possible.
        err = iterate(ctx, nil, 0,
            func() (f taskFunc, ok bool) {
                f, ok = <-task
                return
            },
            func(ctx context.Context, f taskFunc) error {
                return f(ctx)
            },
            func(_ taskFunc, err error) bool {
                if err != nil {
                    // Cancel the group work and stop taking new tasks.
                    failure = err
                    return false
                }
                return true
            },
        )
    }()

    g = func(fn taskFunc) {
        // If wait() wasn't called yet, but a previously scheduled task has
        // failed already, we should ignore the task and avoid deadlock here.
        select {
        case task <- fn:
        case <-done:
        }
    }
    wait = func() error {
        close(task)
        <-done
        return errors.Join(err, failure)
    }
    return
}

So using the function would look very similar to using the errgroup package:

// Create the workers group. The context passed to each task is canceled once
// any of the tasks fails.
g, wait := errgroup(ctx)
g(func(ctx context.Context) error {
    return doSomeFun(ctx)
})
g(func(ctx context.Context) error {
    return doEvenMoreFun(ctx)
})
if err := wait(); err != nil {
    // handle error
}

Go Iterators

Let's briefly look at the near future of Go in relation to the ideas implemented above.

With the recent (as of Go 1.22) range-over-func experiment, it is possible to use a regular for-range loop over functions that are compatible with the sequence iterator types defined by the iter package. This is quite a new concept in Go, which will hopefully ship as part of the standard library in future versions of Go. For more information, please read the range over func proposal as well as the earlier article by Russ Cox on coroutines in Go, on top of which the experimental iter package is built.
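An iter.Seq[V] is just a function of type func(yield func(V) bool). Roughly speaking, a `for v := range seq` loop becomes a call of seq with a yield callback that runs the loop body and returns false on break. The following minimal sketch shows that equivalence. It uses a local `Seq` type that mirrors iter.Seq, so it also compiles on Go versions without the iter package. `count` and `firstEven` are hypothetical examples:

```go
package main

import "fmt"

// Seq mirrors the shape of iter.Seq[V] from the iter package.
type Seq[V any] func(yield func(V) bool)

// count yields 0, 1, ..., n-1 and stops early if yield returns false.
func count(n int) Seq[int] {
	return func(yield func(int) bool) {
		for i := 0; i < n; i++ {
			if !yield(i) {
				return
			}
		}
	}
}

// firstEven returns the first even number greater than zero that seq yields.
// It is written the way a "for v := range seq { ... }" loop is roughly
// desugared: the loop body becomes the yield callback, break becomes
// "return false" and reaching the end of the body becomes "return true".
func firstEven(seq Seq[int]) (found int, ok bool) {
	seq(func(v int) bool {
		if v > 0 && v%2 == 0 {
			found, ok = v, true
			return false // break
		}
		return true // continue
	})
	return found, ok
}

func main() {
	fmt.Println(firstEven(count(10))) // 2 true
}
```

On Go 1.23 and later, where range-over-func is no longer an experiment, the same sequence can be consumed directly with `for v := range count(10)`.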

Adjusting iterate() to be iter-compatible is easy as pie:

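// NOTE: Go has no function overloading, so this wrapper cannot share the
// iterate() name with the pull-based version above in the same package. In
// real code, rename one of them (or put them in different packages). The
// inner call below refers to the pull-based iterate().
func iterate[A, B any](
    ctx context.Context,
    prepare func(context.Context) (context.Context, context.CancelFunc),
    parallelism int,
    seq iter.Seq[A],
    f func(context.Context, A) B,
) iter.Seq2[A, B] {
    return func(yield func(A, B) bool) {
        // iter.Pull turns the push-style seq into a pull-style next()
        // function, which matches the pull argument of iterate().
        pull, stop := iter.Pull(seq)
        defer stop()
        // yield plays the role of push(). The error returned by iterate()
        // is dropped, since iter.Seq2 has no way to report it.
        iterate(ctx, prepare, parallelism, pull, f, yield)
    }
}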
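iter.Pull is what bridges the two styles in the wrapper. It turns the push-style seq into a pull-style next() function, matching the pull argument of iterate(). To make that conversion concrete, here is a naive goroutine-and-channel adapter with the same next/stop shape. The `pullChan` function and the local `Seq` type are hypothetical. iter.Pull itself is built on coroutines rather than channels and is cheaper, so treat this only as an illustration of the semantics:

```go
package main

import (
	"fmt"
	"sync"
)

// Seq mirrors the shape of iter.Seq[V] from the iter package.
type Seq[V any] func(yield func(V) bool)

// pullChan runs seq in its own goroutine and hands out its values one at a
// time through next. stop makes the producer quit early. The caller must
// call it (typically via defer), otherwise the goroutine may leak.
func pullChan[V any](seq Seq[V]) (next func() (V, bool), stop func()) {
	values := make(chan V)
	done := make(chan struct{})
	go func() {
		defer close(values) // Makes next report ok == false at the end.
		seq(func(v V) bool {
			select {
			case values <- v:
				return true
			case <-done:
				return false // stop() was called: ask seq to finish.
			}
		})
	}()
	next = func() (V, bool) {
		v, ok := <-values
		return v, ok
	}
	var once sync.Once
	stop = func() { once.Do(func() { close(done) }) }
	return next, stop
}

func main() {
	numbers := Seq[int](func(yield func(int) bool) {
		for _, v := range []int{1, 2, 3} {
			if !yield(v) {
				return
			}
		}
	})
	next, stop := pullChan(numbers)
	defer stop()
	for {
		v, ok := next()
		if !ok {
			break
		}
		fmt.Println(v) // 1, then 2, then 3
	}
}
```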

Enabling that experiment allows us to do an amazing thing: iterate over the results of concurrently processed elements of a sequence in a regular for loop!

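// slices.Values (Go 1.23+) returns an iter.Seq over the slice elements.
seq := slices.Values([]int{1, 2, 3})
square := func(_ context.Context, n int) int { return n * n }

// Prints "1 1", "2 4" and "3 9", in no particular order.
for a, b := range iterate(ctx, nil, 0, seq, square) {
    fmt.Println(a, b)
}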

Conclusion

I wish this was a part of the Go standard library.

Initially I wanted this first sentence to be the only content of the conclusion, but a few words should probably be said about why. I believe such general purpose utilities are much better conveyed to and accepted by projects if the majority of the community agrees on how they are designed and built. Of course, we can have libraries that solve similar problems, but in my opinion, the more different libraries we have, the more disagreement we may get in the community about what, when and how to use them. In some cases there is nothing wrong with having widely different approaches and implementations, but in others it can also mean not having a complete solution at all.

Very often libraries are born as much more specific solutions than wide adoption requires, and to become a truly general purpose solution, the design, the API and then the implementation should be well discussed long before the actual work takes place. This is how OSS foundations solve similar problems, and how the Go team does it in the case of Go. Having something for such concurrent/asynchronous processing feels like a natural evolution after getting the generic slices package and, later, coroutines and iterators.

Many thanks for your attention.
This article was really hard to write and took a long time to polish. I hope it was helpful!

