Lambda Store

Timeout for sync.WaitGroup

Mattia Bianchi ・3 min read

sync.WaitGroup is a very useful concurrency primitive in Go. It's mainly used to wait for a set of operations running in other goroutines to complete, much like CountDownLatch in Java. However, it lacks the ability to stop waiting after a predefined timeout if some of the goroutines fail to complete in time.

CountDownLatch in Java has two await() methods: one blocks indefinitely, the other blocks for at most a given timeout:

/**
 * Causes the current thread to wait until the latch has counted down to
 * zero, unless the thread is {@linkplain Thread#interrupt interrupted}.
 * ...
 */
public void await() throws InterruptedException {
...
}

/**
 * Causes the current thread to wait until the latch has counted down to
 * zero, unless the thread is {@linkplain Thread#interrupt interrupted},
 * or the specified waiting time elapses.
 * ...
 */
public boolean await(long timeout, TimeUnit unit) throws InterruptedException {
...
}

The only option sync.WaitGroup offers is to wait indefinitely until its counter becomes zero:

// Wait blocks until the WaitGroup counter is zero.
func (wg *WaitGroup) Wait() {
...
}

If you search for how to time out a WaitGroup, you'll find that most (if not all) suggestions use a select statement with a timeout.

I'll show another way of adding a timeout to sync.WaitGroup: embedding it in another struct type and redefining the Add(delta int) and Done() methods. Then I'll add a new method, Await(d time.Duration) bool, that waits with a timeout.


First, let's define a new struct type, TimedWaitGroup, that embeds sync.WaitGroup:

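type TimedWaitGroup struct {
    // Counter. Kept as the first field so that it is 64-bit aligned for
    // atomic access on 32-bit platforms (see the sync/atomic docs).
    // Higher 32 bits: decremented by both wg.Done() and timeout.
    // Lower 32 bits: decremented only by wg.Done().
    counter uint64

    sync.WaitGroup
}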

The counter field is an 8-byte integer. It's actually a combination of two counters: the higher 32 bits and the lower 32 bits of a uint64. When Add(delta) is called, delta is added to both counters atomically.

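func (wg *TimedWaitGroup) Add(delta int) {
    wg.WaitGroup.Add(delta)
    // uint64(delta) sign-extends, so multiplying by combineToUint64(1, 1)
    // adds delta to both halves, negative deltas included (arithmetic wraps
    // modulo 2^64). With uint32(delta), a negative delta would carry out of
    // the low half and leave the high half unchanged.
    atomic.AddUint64(&wg.counter, uint64(delta)*combineToUint64(1, 1))
}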

The combineToUint64(x, y uint32) function combines two uint32 integers into a single uint64, placing x in the higher 32 bits: uint64(x)<<32 | uint64(y).
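A quick check of the packing, as a small sketch. The first print shows that adding combineToUint64(3, 3) (what Add(3) does) sets both halves to 3. The second shows that packing uint32(delta) for a negative delta only decrements the low half, because the carry out of the low half cancels the change to the high half. The third shows one way to handle negative deltas: sign-extend with uint64(delta) and multiply by combineToUint64(1, 1).

```go
package main

import "fmt"

// combineToUint64 puts x in the high 32 bits and y in the low 32 bits.
func combineToUint64(x, y uint32) uint64 {
	return uint64(x)<<32 | uint64(y)
}

func main() {
	var counter uint64
	counter += combineToUint64(3, 3) // like Add(3): both halves become 3
	fmt.Printf("%#x\n", counter)     // 0x300000003

	delta := -1 // as if Add(-1) were called

	// uint32(delta) wraps to 0xffffffff; the carry out of the low half
	// cancels the change to the high half, so only the low half drops.
	d := uint32(delta)
	fmt.Printf("%#x\n", counter+combineToUint64(d, d)) // 0x300000002

	// uint64(delta) sign-extends; multiplying by combineToUint64(1, 1)
	// subtracts 1 from both halves (arithmetic wraps modulo 2^64).
	fmt.Printf("%#x\n", counter+uint64(delta)*combineToUint64(1, 1)) // 0x200000002
}
```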

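The Done() method decrements both counters by one using an atomic compare-and-swap (CAS). The original WaitGroup.Done() is called if and only if the CAS succeeds. If the CAS fails because another goroutine changed the counter in the meantime, Done() reloads the counter and retries, as long as the higher 32-bit counter is still greater than zero. If it has already dropped to zero (the timeout has fired), Done() returns without doing anything.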

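func (wg *TimedWaitGroup) Done() {
    c := atomic.LoadUint64(&wg.counter)
    hc, lc := splitUint64(c)
    // If hc is already zero, the timeout has completed the group and this
    // late Done() is a no-op.
    for hc > 0 {
        // Decrement both halves; on success, mirror it in the WaitGroup.
        if atomic.CompareAndSwapUint64(&wg.counter, c, combineToUint64(hc-1, lc-1)) {
            wg.WaitGroup.Done()
            return
        }
        // Another goroutine changed the counter; reload and retry.
        c = atomic.LoadUint64(&wg.counter)
        hc, lc = splitUint64(c)
    }
}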

splitUint64(value uint64) splits a uint64 into its higher and lower uint32 halves: uint32(value >> 32), uint32(value).
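A minimal round-trip sketch of the two helpers, written from their descriptions in the text, shows that splitUint64 undoes combineToUint64:

```go
package main

import "fmt"

// combineToUint64 puts x in the high 32 bits and y in the low 32 bits.
func combineToUint64(x, y uint32) uint64 { return uint64(x)<<32 | uint64(y) }

// splitUint64 returns the high and low 32-bit halves of value.
func splitUint64(value uint64) (uint32, uint32) {
	return uint32(value >> 32), uint32(value)
}

func main() {
	hc, lc := splitUint64(combineToUint64(2, 5))
	fmt.Println(hc, lc) // 2 5
}
```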

The higher 32-bit counter is decremented both by Done() calls and when the specified timeout elapses. When it reaches zero, either all goroutines have finished or the timeout has elapsed. The lower 32-bit counter, however, is decremented only by Done(). So the lower 32-bit counter remains positive if the TimedWaitGroup was completed by the timeout.
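To trace how the two halves evolve, here is a hypothetical single-threaded model (counterModel, done and timeout are illustrative names; there are no atomics and no WaitGroup). Three goroutines are added, one finishes in time, the timeout fires, and a late Done() arrives afterwards:

```go
package main

import "fmt"

// counterModel is a hypothetical, non-concurrent model of the two halves
// of TimedWaitGroup.counter, used only to trace state transitions.
type counterModel struct{ high, low uint32 }

// done mirrors Done(): it decrements both halves, but only while high > 0.
func (m *counterModel) done() {
	if m.high > 0 {
		m.high--
		m.low--
	}
}

// timeout mirrors the timer callback: it zeroes high and leaves low intact.
func (m *counterModel) timeout() { m.high = 0 }

func main() {
	m := counterModel{high: 3, low: 3} // after Add(3)
	m.done()                           // one goroutine finishes in time
	m.timeout()                        // timer fires with two still running
	m.done()                           // a late Done() is ignored
	fmt.Println(m.high, m.low)           // 0 2
	fmt.Println("timed out:", m.low > 0) // timed out: true
}
```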

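// Await waits until the counter reaches zero or d elapses, whichever comes
// first, and reports whether all goroutines called Done() in time.
// Once Await returns, further calls to Wait() will return immediately.
func (wg *TimedWaitGroup) Await(d time.Duration) bool {
    t := time.AfterFunc(d, func() {
        c := atomic.LoadUint64(&wg.counter)
        hc, lc := splitUint64(c)
        for hc > 0 {
            // Zero the high half only; the low half records unfinished work.
            if atomic.CompareAndSwapUint64(&wg.counter, c, combineToUint64(0, lc)) {
                wg.WaitGroup.Add(-int(hc))
                break
            }
            c = atomic.LoadUint64(&wg.counter)
            hc, lc = splitUint64(c)
        }
    })

    wg.Wait()
    t.Stop() // stop the timer if everyone finished before it fired
    // The high half is zero here, so the counter is zero iff the low half is.
    return atomic.LoadUint64(&wg.counter) == 0
}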

The Await(d time.Duration) method starts a timer whose callback completes the WaitGroup if the timeout elapses before the counter reaches zero. Unlike Done(), this callback leaves the lower 32-bit counter intact and sets the higher 32-bit counter to zero using an atomic CAS.

If the CAS succeeds, the WaitGroup's counter is decremented by the value the higher 32-bit counter held. At that point (once any Done() calls already in flight have completed) the WaitGroup's counter is zero, and WaitGroup.Wait() returns as if all goroutines had finished.

Finally, Await() reads the counter to tell the two cases apart: it returns false if the WaitGroup was completed by the timeout (the lower 32-bit counter is still positive) and true if all goroutines finished in time.

Lambda Store is the first serverless Redis service. In this blog, the Lambda Store engineering team shares their experiences with Cloud, AWS, Kubernetes, Redis and, of course, Lambda Store.
