DEV Community

Ilya Voronin

Celery + SQS: Stop Broken Workers from Monopolizing Your Queue with Circuit Breakers

The Problem

You have 10 Celery workers processing tasks from SQS. One worker's GPU fails. Here's what happens:

  • Healthy worker: Takes 60 seconds to process a task
  • Broken worker: Fails in 0.5 seconds, immediately grabs next task

The broken worker consumes (and fails) tasks 120 times faster than a healthy one (60 s / 0.5 s). Within a minute it grabs over 90% of the messages in your queue and fails every one of them.
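The arithmetic behind that claim, as a quick sketch: it assumes the queue never runs dry and that every worker fetches the next message the moment it finishes. The helper name `broken_worker_share` is purely illustrative.

```python
def broken_worker_share(
    healthy_workers: int,
    healthy_secs: float,
    broken_secs: float,
    window_secs: float = 60.0,
) -> float:
    """Fraction of all tasks consumed by one broken worker within a time window.

    Assumes a never-empty queue and workers that fetch immediately after each task.
    """
    healthy_tasks = healthy_workers * (window_secs / healthy_secs)  # 9 * 1 = 9
    broken_tasks = window_secs / broken_secs                        # 60 / 0.5 = 120
    return broken_tasks / (healthy_tasks + broken_tasks)


share = broken_worker_share(healthy_workers=9, healthy_secs=60.0, broken_secs=0.5)
print(f"{share:.1%}")  # → 93.0%
```

Nine healthy workers finish 9 tasks per minute between them, while the broken one burns through 120, so it takes 120 of every 129 messages.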

This is systemic failure amplification: one broken component paralyzes the entire system.

The Solution

When a worker keeps failing, stop it from taking new tasks for a timeout period.

The challenge: Celery + SQS doesn't have a built-in way to pause consumption based on worker health. There's no API to say "stop fetching messages for 60 seconds because this worker is broken."

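Our approach: patch Celery's consumption check with a consumer bootstep, and use shared memory for cross-process signaling. Shared memory is needed because tasks (and therefore the circuit breaker) run in the pool's child processes, while the consumer that fetches messages from SQS runs in the main worker process.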
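The cross-process signal on its own is small. Here's a minimal sketch of it outside Celery, assuming a POSIX `fork` start method (what Celery's prefork pool uses on Linux): a child process writes a deadline into a shared `c_double`, and the parent sees the write. The function names are hypothetical, and `time.time()` stands in for the `datetime.now(UTC).timestamp()` used in the full patch (both give Unix-epoch seconds).

```python
import ctypes
import multiprocessing as mp
import time

# Assumption: POSIX fork. The Value must exist BEFORE forking so children inherit it.
ctx = mp.get_context("fork")
paused_until = ctx.Value(ctypes.c_double, 0.0)


def child_trips_breaker(pause_secs: float) -> None:
    # Runs in a child process, like a pool worker whose breaker just opened.
    paused_until.value = time.time() + pause_secs


def parent_is_paused() -> bool:
    # Runs in the parent, like the consumer deciding whether to fetch from SQS.
    return time.time() < paused_until.value


if __name__ == "__main__":
    p = ctx.Process(target=child_trips_breaker, args=(60.0,))
    p.start()
    p.join()
    print(parent_is_paused())  # True: the child's write is visible in the parent
```

A plain module-level float would not work here: after `fork`, each process gets its own copy-on-write memory, so a write in the child would never reach the parent.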

```python
import ctypes
from datetime import UTC, datetime  # the UTC alias needs Python 3.11+
from multiprocessing import Value
from types import MethodType

import pybreaker
from celery import Celery, bootsteps

# Assumption: this stands in for your existing Celery app configured for SQS.
app = Celery("workflows", broker="sqs://")

# 1. Shared memory, created at import time in the main worker process,
#    so the forked pool processes inherit the same memory
_paused_until = Value(ctypes.c_double, 0.0)

# 2. When the circuit opens (inside a pool process), write the pause deadline
class CircuitBreakerPauseController(pybreaker.CircuitBreakerListener):
    def state_change(self, cb, old_state, new_state):
        if new_state.name == pybreaker.STATE_OPEN:
            pause_until_ts = datetime.now(UTC).timestamp() + cb.reset_timeout
            _paused_until.value = pause_until_ts

# 3. Patch can_consume() so the consumer (main process) stops fetching from SQS
class CircuitBreakerConsumptionGate(bootsteps.StartStopStep):
    requires = ("celery.worker.consumer.tasks:Tasks",)

    def start(self, parent):
        channel_qos = parent.task_consumer.channel.qos
        original = channel_qos.can_consume  # keep the normal prefetch check

        def can_consume(self):
            pause_until_ts = _paused_until.value
            if pause_until_ts > 0.0 and datetime.now(UTC).timestamp() < pause_until_ts:
                return False  # Block fetching while the circuit is open
            return bool(original())

        channel_qos.can_consume = MethodType(can_consume, channel_qos)

# 4. Circuit breaker with the pause listener
breaker = pybreaker.CircuitBreaker(
    fail_max=3,
    reset_timeout=60,
    throw_new_error_on_trip=False,  # Keep original exceptions
    listeners=[CircuitBreakerPauseController()],
)

# 5. Register the bootstep and wrap tasks
app.steps["consumer"].add(CircuitBreakerConsumptionGate)

def _execute_workflow(workflow):
    ...  # placeholder: the real (e.g. GPU) work goes here

@app.task(bind=True)
def process_workflow(self, workflow):
    return breaker.call(_execute_workflow, workflow)
```
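The gate from step 3 can be exercised without a broker. This sketch applies the same `MethodType` patch to a hypothetical `FakeQoS` object that only mimics the `can_consume()` method the bootstep wraps. The clock and the pause deadline are injected as callables (an assumption made for testability) so the pause window can be checked deterministically.

```python
from types import MethodType


class FakeQoS:
    """Hypothetical stand-in for a kombu channel QoS object."""

    def __init__(self, has_capacity: bool = True):
        self.has_capacity = has_capacity

    def can_consume(self) -> bool:
        return self.has_capacity  # the "normal" prefetch check


def install_gate(qos, get_paused_until, now):
    """Wrap qos.can_consume() so it returns False until the pause deadline passes."""
    original = qos.can_consume  # bound method captured before patching

    def can_consume(self):
        pause_until_ts = get_paused_until()
        if pause_until_ts > 0.0 and now() < pause_until_ts:
            return False  # circuit open: don't fetch
        return bool(original())  # otherwise defer to the original check

    qos.can_consume = MethodType(can_consume, qos)


clock = {"t": 1000.0}
state = {"paused_until": 0.0}
qos = FakeQoS()
install_gate(qos, lambda: state["paused_until"], lambda: clock["t"])

print(qos.can_consume())                # True: no pause set
state["paused_until"] = clock["t"] + 60  # breaker opened
print(qos.can_consume())                # False: inside the pause window
clock["t"] += 61
print(qos.can_consume())                # True: pause expired
```

Capturing `original` before patching matters: calling `qos.can_consume()` inside the wrapper after the patch would recurse into the wrapper itself.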

How it works:

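  1. A pool process's breaker records 3 consecutive failures → circuit opens (each forked process keeps its own breaker counters)
  2. PauseController writes _paused_until = now + 60s to shared memory
  3. The consumer in the main worker process checks _paused_until in can_consume()
  4. can_consume() returns False → no SQS fetching for 60 seconds
  5. After the timeout, fetching resumes and the circuit lets one trial task through (HALF_OPEN state)
  6. Success → close circuit, failure → pause again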
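The CLOSED → OPEN → HALF_OPEN cycle in the steps above can be sketched with a deliberately simplified breaker. `MiniBreaker` is hypothetical and is not pybreaker's implementation: it ignores concurrency, raises a plain `RuntimeError` where pybreaker raises `CircuitBreakerError`, and takes the current time as an argument. Its `on_open` callback plays the role of the pause listener.

```python
from enum import Enum


class State(Enum):
    CLOSED = "closed"
    OPEN = "open"
    HALF_OPEN = "half_open"


class MiniBreaker:
    """Simplified, single-threaded breaker for illustration only."""

    def __init__(self, fail_max: int, reset_timeout: float, on_open=None):
        self.fail_max = fail_max
        self.reset_timeout = reset_timeout
        self.on_open = on_open  # called with the pause deadline, like the listener
        self.state = State.CLOSED
        self.failures = 0
        self.opened_at = 0.0

    def call(self, func, now: float):
        if self.state is State.OPEN:
            if now - self.opened_at < self.reset_timeout:
                raise RuntimeError("circuit open")  # reject without calling func
            self.state = State.HALF_OPEN  # timeout elapsed: allow one trial call
        try:
            result = func()
        except Exception:
            self._record_failure(now)
            raise
        self.failures = 0  # any success resets the consecutive-failure count
        self.state = State.CLOSED
        return result

    def _record_failure(self, now: float):
        self.failures += 1
        if self.state is State.HALF_OPEN or self.failures >= self.fail_max:
            self.state = State.OPEN
            self.opened_at = now
            if self.on_open:
                self.on_open(now + self.reset_timeout)


pauses = []
breaker = MiniBreaker(fail_max=3, reset_timeout=60, on_open=pauses.append)


def gpu_task():
    raise OSError("CUDA error")  # hypothetical failing task


for t in (0.0, 0.5, 1.0):
    try:
        breaker.call(gpu_task, now=t)
    except OSError:
        pass

print(breaker.state, pauses)  # State.OPEN [61.0]
```

A trial call at `now=61.5` would move the breaker to HALF_OPEN and, if it succeeds, back to CLOSED; a failing trial reopens it and publishes a new pause deadline.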

The pattern transforms failing workers from system-killers into isolated incidents.
