## The Problem
You have 10 Celery workers processing tasks from SQS. One worker's GPU fails. Here's what happens:
- Healthy worker: Takes 60 seconds to process a task
- Broken worker: Fails in 0.5 seconds, immediately grabs next task
The broken worker is 120x faster at consuming (and failing) tasks. In one minute, it grabs roughly 90% of your queue and fails every one of those tasks.
This is systemic failure amplification — one broken component paralyzes the entire system.
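The amplification is simple arithmetic. A minimal sketch using the illustrative numbers above (10 workers, 60 s per healthy task, 0.5 s per failure; these are the scenario's numbers, not measurements):

```python
# Illustrative numbers from the scenario above (not measurements)
HEALTHY_SECONDS = 60.0   # time for a healthy worker to finish one task
BROKEN_SECONDS = 0.5     # time for the broken worker to fail one task
HEALTHY_WORKERS = 9      # 10 workers total, 1 broken

# Tasks consumed per minute by each group
healthy_rate = HEALTHY_WORKERS * (60.0 / HEALTHY_SECONDS)  # 9 tasks/min
broken_rate = 60.0 / BROKEN_SECONDS                        # 120 tasks/min

failed_share = broken_rate / (broken_rate + healthy_rate)
print(f"Broken worker consumes {failed_share:.0%} of the queue")  # → 93%
```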
## The Solution
When a worker keeps failing, stop it from taking new tasks for a timeout period.
The challenge: Celery + SQS doesn't have a built-in way to pause consumption based on worker health. There's no API to say "stop fetching messages for 60 seconds because this worker is broken."
Our approach: Patch Celery's consumption mechanism using bootsteps + shared memory for cross-process signaling.
```python
from datetime import datetime, UTC
from multiprocessing import Value
from types import MethodType
import ctypes

import pybreaker
from celery import bootsteps

# 1. Shared memory between all processes
_paused_until = Value(ctypes.c_double, 0.0)

# 2. When the circuit opens, write the pause deadline to shared memory
class CircuitBreakerPauseController(pybreaker.CircuitBreakerListener):
    def state_change(self, cb, old_state, new_state):
        if new_state.name == pybreaker.STATE_OPEN:
            pause_until_ts = datetime.now(UTC).timestamp() + cb.reset_timeout
            _paused_until.value = pause_until_ts

# 3. Patch can_consume() to block SQS fetching during the pause
class CircuitBreakerConsumptionGate(bootsteps.StartStopStep):
    requires = ("celery.worker.consumer.tasks:Tasks",)

    def start(self, parent):
        channel_qos = parent.task_consumer.channel.qos
        original = channel_qos.can_consume

        def can_consume(self):
            pause_until_ts = _paused_until.value
            if pause_until_ts > 0.0:
                if datetime.now(UTC).timestamp() < pause_until_ts:
                    return False  # Block fetching
            return bool(original())

        channel_qos.can_consume = MethodType(can_consume, channel_qos)

# 4. Circuit breaker with the pause listener attached
breaker = pybreaker.CircuitBreaker(
    fail_max=3,
    reset_timeout=60,
    throw_new_error_on_trip=False,  # Keep original exceptions
    listeners=[CircuitBreakerPauseController()],
)

# 5. Register the bootstep and wrap tasks (`app` is your Celery app instance)
app.steps["consumer"].add(CircuitBreakerConsumptionGate)

@app.task(bind=True)
def process_workflow(self, workflow):
    return breaker.call(self._execute_workflow, workflow)
```
How it works:
- Worker fails 3 times → circuit opens
- `CircuitBreakerPauseController` writes `_paused_until = now + 60s` to shared memory
- The main process checks `_paused_until` in `can_consume()`
- It returns `False` → no SQS fetching for 60 seconds
- After the timeout, the circuit tries one task (HALF_OPEN state)
- Success → circuit closes; failure → pause again
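The open → half-open → closed cycle described above can be sketched without pybreaker. This toy state machine is illustrative only (the names are not pybreaker's API), but it follows the same transitions:

```python
FAIL_MAX = 3  # matches fail_max=3 in the breaker config above

class ToyBreaker:
    """Illustrative state machine, not pybreaker's implementation."""
    def __init__(self):
        self.state = "closed"
        self.failures = 0

    def record_failure(self):
        self.failures += 1
        # A failure in HALF_OPEN re-trips immediately; otherwise count to the limit
        if self.state == "half_open" or self.failures >= FAIL_MAX:
            self.state = "open"

    def record_success(self):
        self.failures = 0
        self.state = "closed"  # healthy again: resume normally

    def timeout_elapsed(self):
        if self.state == "open":
            self.state = "half_open"  # allow exactly one trial task

b = ToyBreaker()
for _ in range(3):
    b.record_failure()
assert b.state == "open"        # 3 failures → circuit opens

b.timeout_elapsed()
assert b.state == "half_open"   # after reset_timeout, one trial allowed

b.record_success()
assert b.state == "closed"      # trial succeeded → back to normal
```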
The pattern transforms failing workers from system-killers into isolated incidents.