How to make your custom objects behave exactly like native types through protocol adherence
The "Native Citizen" Test
Here's how you spot a junior Python developer:
# Junior code
m1 = Money(10, "USD")
m2 = Money(5, "USD")
result = m1.add(m2) # Java flashbacks intensify
And here's a senior:
# Senior code
m1 = Money(10, "USD")
m2 = Money(5, "USD")
result = m1 + m2 # Feels like native Python
The difference isn't just aesthetics; it reflects an understanding of Python's protocol-oriented design philosophy.
Python doesn't care if your class inherits from int or float. It only cares if your object behaves like a number. This is duck typing at the language level: if it implements the arithmetic protocol, it's a number. If it implements the sequence protocol, it's a sequence.
Today, we're going to make our custom classes first-class citizens of the Python ecosystem by implementing the magic methods that let them integrate seamlessly with operators, loops, and built-in functions.
The Arithmetic Negotiation: __add__ vs __radd__
Let's build a Money class that supports addition:
class Money:
    def __init__(self, amount, currency):
        self.amount = amount
        self.currency = currency

    def __add__(self, other):
        if isinstance(other, Money):
            if self.currency != other.currency:
                raise ValueError("Cannot add different currencies")
            return Money(self.amount + other.amount, self.currency)
        return NotImplemented

    def __repr__(self):
        return f"Money({self.amount!r}, {self.currency!r})"


m1 = Money(10, "USD")
m2 = Money(5, "USD")
print(m1 + m2)  # Money(15, 'USD')
Great! But what about this?
m = Money(10, "USD")
result = m + 5 # What should this do?
We could support it:
def __add__(self, other):
    if isinstance(other, Money):
        if self.currency != other.currency:
            raise ValueError("Cannot add different currencies")
        return Money(self.amount + other.amount, self.currency)
    if isinstance(other, (int, float)):
        return Money(self.amount + other, self.currency)
    return NotImplemented
Now m + 5 works. But what about 5 + m?
m = Money(10, "USD")
result = 5 + m # TypeError: unsupported operand type(s) for +: 'int' and 'Money'
The Dispatch Sequence
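Here's what actually happens when you write a + b:
- Python calls a.__add__(b).
- If that method is missing or returns NotImplemented, Python calls b.__radd__(a). (If a and b have the same type, this step is skipped.)
- If that is also missing or returns NotImplemented, Python raises TypeError.
One exception to the order: if b's type is a subclass of a's type and overrides __radd__, Python tries b.__radd__(a) first, so subclasses can take precedence over their parents.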
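To watch this dispatch order happen, here is a small tracing sketch. The classes Left, Right and LeftChild are hypothetical and exist only to record which special method Python calls. It covers three cases: the normal order, a subclass on the right that overrides __radd__, and two operands of the same type.

```python
calls = []  # records which special methods Python invokes, in order


class Left:
    def __add__(self, other):
        calls.append("Left.__add__")
        return NotImplemented  # decline, so Python may try the right operand

    def __radd__(self, other):
        calls.append("Left.__radd__")
        return NotImplemented


class Right:
    def __radd__(self, other):
        calls.append("Right.__radd__")
        return "handled by Right"


class LeftChild(Left):
    # Subclass of Left that overrides the reflected method
    def __radd__(self, other):
        calls.append("LeftChild.__radd__")
        return "handled by LeftChild"


# 1. Normal order: left __add__ first, then right __radd__
print(Left() + Right())      # handled by Right
print(calls)                 # ['Left.__add__', 'Right.__radd__']

# 2. A subclass on the right that overrides __radd__ goes first
calls.clear()
print(Left() + LeftChild())  # handled by LeftChild
print(calls)                 # ['LeftChild.__radd__']

# 3. Same type on both sides: __radd__ is never tried
calls.clear()
try:
    Left() + Left()
except TypeError as exc:
    print("TypeError:", exc)
print(calls)                 # ['Left.__add__']
```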
So when you do 5 + m:
- Python calls (5).__add__(m). int.__add__ doesn't know about Money, so it returns NotImplemented.
- Python then calls m.__radd__(5), which we haven't implemented yet!
The solution is reflected operations:
class Money:
    def __init__(self, amount, currency):
        self.amount = amount
        self.currency = currency

    def __add__(self, other):
        if isinstance(other, Money):
            if self.currency != other.currency:
                raise ValueError("Cannot add different currencies")
            return Money(self.amount + other.amount, self.currency)
        if isinstance(other, (int, float)):
            return Money(self.amount + other, self.currency)
        return NotImplemented

    def __radd__(self, other):
        # Reflected addition: other + self
        # Just delegate to __add__ since addition is commutative
        return self.__add__(other)

    def __repr__(self):
        return f"Money({self.amount!r}, {self.currency!r})"


m = Money(10, "USD")
print(m + 5)  # Money(15, 'USD')
print(5 + m)  # Money(15, 'USD') - now works!
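A practical payoff of __radd__: the built-in sum() starts from the integer 0, so the very first addition it performs is 0 + Money(...). That only works because Money.__radd__ exists. A minimal sketch reusing the same Money design:

```python
class Money:
    def __init__(self, amount, currency):
        self.amount = amount
        self.currency = currency

    def __add__(self, other):
        if isinstance(other, Money):
            if self.currency != other.currency:
                raise ValueError("Cannot add different currencies")
            return Money(self.amount + other.amount, self.currency)
        if isinstance(other, (int, float)):
            return Money(self.amount + other, self.currency)
        return NotImplemented

    def __radd__(self, other):
        # Handles 0 + Money(...) on sum()'s first step
        return self.__add__(other)

    def __repr__(self):
        return f"Money({self.amount!r}, {self.currency!r})"


wallet = [Money(10, "USD"), Money(5, "USD"), Money(2.5, "USD")]
print(sum(wallet))  # Money(17.5, 'USD')
```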
The Critical Distinction: NotImplemented vs NotImplementedError
This is where juniors crash and seniors cooperate:
# WRONG - Don't do this!
def __add__(self, other):
    if isinstance(other, Money):
        return Money(self.amount + other.amount, self.currency)
    raise NotImplementedError("Cannot add Money and " + type(other).__name__)


# RIGHT - Be a good citizen
def __add__(self, other):
    if isinstance(other, Money):
        return Money(self.amount + other.amount, self.currency)
    return NotImplemented
What's the difference?
- NotImplementedError: an exception. Raising it ends the dispatch on the spot; the other operand never gets a chance, and the error propagates to the caller. You're saying "I can't do this, and nobody else gets to try either."
- NotImplemented: a sentinel value. You're saying "I don't know how to handle this, but maybe the other object does."
When you return NotImplemented, you're participating in Python's cooperative operator dispatch. You're giving the other object a chance to handle the operation.
class Discount:
    def __init__(self, percent):
        self.percent = percent

    def __radd__(self, other):
        if isinstance(other, Money):
            discount_amount = other.amount * (self.percent / 100)
            return Money(other.amount - discount_amount, other.currency)
        return NotImplemented

    def __repr__(self):
        return f"Discount({self.percent}%)"


m = Money(100, "USD")
d = Discount(10)
# This works because Money.__add__ returns NotImplemented,
# so Python tries Discount.__radd__
result = m + d  # Money(90.0, 'USD')
If Money.__add__ had raised NotImplementedError, this would never work.
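To see the difference in isolation, here is a minimal sketch with hypothetical classes: Tag knows how to sit on the right of +, Cooperative declines by returning NotImplemented, and Uncooperative raises NotImplementedError, which stops Python before Tag.__radd__ is ever consulted.

```python
class Tag:
    """Hypothetical right-hand operand that can be added to anything."""

    def __radd__(self, other):
        return f"tagged {type(other).__name__}"


class Cooperative:
    def __add__(self, other):
        return NotImplemented  # "I can't, but maybe the other operand can"


class Uncooperative:
    def __add__(self, other):
        raise NotImplementedError("cannot add")  # ends dispatch immediately


print(Cooperative() + Tag())  # tagged Cooperative

try:
    Uncooperative() + Tag()
except NotImplementedError as exc:
    print("NotImplementedError:", exc)  # Tag.__radd__ never ran
```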
The Full Arithmetic Protocol
For a complete numeric type, you need:
- __add__, __radd__ - Addition (a + b)
- __sub__, __rsub__ - Subtraction (a - b)
- __mul__, __rmul__ - Multiplication (a * b)
- __truediv__, __rtruediv__ - Division (a / b)
- __floordiv__, __rfloordiv__ - Floor division (a // b)
- __mod__, __rmod__ - Modulo (a % b)
- __pow__, __rpow__ - Exponentiation (a ** b)
And the in-place variants:
- __iadd__ - In-place addition (a += b)
- __isub__, __imul__, etc.
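An immutable type like Money can skip the in-place methods: when __iadd__ is not defined, Python evaluates a += b as a = a + b and rebinds the name to the new object. A mutable type can instead implement __iadd__ to update itself and return self. A sketch contrasting the two; Wallet is a hypothetical mutable class, and this trimmed Money only handles numbers:

```python
class Money:
    """Immutable: defines __add__ but no __iadd__."""

    def __init__(self, amount, currency):
        self.amount = amount
        self.currency = currency

    def __add__(self, other):
        if isinstance(other, (int, float)):
            return Money(self.amount + other, self.currency)
        return NotImplemented


class Wallet:
    """Hypothetical mutable running total that updates itself in place."""

    def __init__(self):
        self.total = 0

    def __iadd__(self, amount):
        if isinstance(amount, (int, float)):
            self.total += amount
            return self  # whatever __iadd__ returns is bound back to the name
        return NotImplemented


m = Money(10, "USD")
before = m
m += 5                          # no __iadd__: runs m = m + 5
print(m is before)              # False - m now names a new object
print(before.amount, m.amount)  # 10 15

w = Wallet()
same = w
w += 5                          # calls w.__iadd__(5)
print(w is same, w.total)       # True 5
```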
class Money:
    def __init__(self, amount, currency):
        self.amount = amount
        self.currency = currency

    def __repr__(self):
        return f"Money({self.amount!r}, {self.currency!r})"

    def __add__(self, other):
        if isinstance(other, Money):
            if self.currency != other.currency:
                raise ValueError("Cannot add different currencies")
            return Money(self.amount + other.amount, self.currency)
        if isinstance(other, (int, float)):
            return Money(self.amount + other, self.currency)
        return NotImplemented

    def __radd__(self, other):
        return self.__add__(other)

    def __sub__(self, other):
        if isinstance(other, Money):
            if self.currency != other.currency:
                raise ValueError("Cannot subtract different currencies")
            return Money(self.amount - other.amount, self.currency)
        if isinstance(other, (int, float)):
            return Money(self.amount - other, self.currency)
        return NotImplemented

    def __rsub__(self, other):
        # Subtraction is NOT commutative: other - self
        if isinstance(other, (int, float)):
            return Money(other - self.amount, self.currency)
        return NotImplemented

    def __mul__(self, other):
        # Money * Money is meaningless, so only scalars are accepted
        if isinstance(other, (int, float)):
            return Money(self.amount * other, self.currency)
        return NotImplemented

    def __rmul__(self, other):
        return self.__mul__(other)
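The class above stops at multiplication. As a sketch of how the same pattern extends to division, here is one possible __truediv__/__rtruediv__ pair. The design choices are assumptions, not the only option: Money / number scales the amount, Money / Money (same currency) returns a plain ratio, and number / Money is treated as meaningless, so __rtruediv__ declines and Python raises TypeError.

```python
class Money:
    def __init__(self, amount, currency):
        self.amount = amount
        self.currency = currency

    def __truediv__(self, other):
        if isinstance(other, Money):
            if self.currency != other.currency:
                raise ValueError("Cannot divide different currencies")
            return self.amount / other.amount  # a ratio is a plain number
        if isinstance(other, (int, float)):
            return Money(self.amount / other, self.currency)
        return NotImplemented

    def __rtruediv__(self, other):
        # number / Money has no sensible meaning here, so decline.
        # Equivalent to omitting the method; written out to make the choice explicit.
        return NotImplemented

    def __repr__(self):
        return f"Money({self.amount!r}, {self.currency!r})"


print(Money(10, "USD") / 4)                # Money(2.5, 'USD')
print(Money(10, "USD") / Money(4, "USD"))  # 2.5

try:
    4 / Money(10, "USD")
except TypeError as exc:
    print("TypeError:", exc)
```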
The Comparison Shortcut: @total_ordering
Now we need comparison operators. All six of them:
- __eq__ (equality: ==)
- __ne__ (inequality: !=)
- __lt__ (less than: <)
- __le__ (less than or equal: <=)
- __gt__ (greater than: >)
- __ge__ (greater than or equal: >=)
That's a lot of boilerplate. Fortunately, Python provides a shortcut:
from functools import total_ordering


@total_ordering
class Money:
    def __init__(self, amount, currency):
        self.amount = amount
        self.currency = currency

    def __eq__(self, other):
        if not isinstance(other, Money):
            return NotImplemented
        return self.amount == other.amount and self.currency == other.currency

    def __lt__(self, other):
        if not isinstance(other, Money):
            return NotImplemented
        if self.currency != other.currency:
            raise ValueError("Cannot compare different currencies")
        return self.amount < other.amount

    def __repr__(self):
        return f"Money({self.amount!r}, {self.currency!r})"


# Now ALL of these work:
m1 = Money(10, "USD")
m2 = Money(20, "USD")
print(m1 < m2)   # True
print(m1 <= m2)  # True
print(m1 > m2)   # False
print(m1 >= m2)  # False
print(m1 == m2)  # False
print(m1 != m2)  # True
The @total_ordering decorator automatically generates __le__, __gt__, and __ge__ based on your __eq__ and __lt__ implementations.
The Mathematical Inference
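How does it work? It infers the missing operators from __lt__ and __eq__, which is only valid for a total order (any two values can be compared):
- a <= b is computed as a < b or a == b
- a > b is computed as not (a < b) and a != b
- a >= b is computed as not (a < b)
- a != b is equivalent to not (a == b), but the decorator doesn't generate it: Python 3's default __ne__ already does exactly this.
The decorator attaches the derived methods to the class. If __lt__ returns NotImplemented, the derived methods pass NotImplemented along, so cooperative dispatch keeps working.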
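You can check which methods the decorator actually adds by inspecting the class dictionary after decoration. This sketch assumes the same Money comparison methods as above; note that __ne__ is not among the added names, yet != still works because object.__ne__ inverts __eq__.

```python
from functools import total_ordering


@total_ordering
class Money:
    def __init__(self, amount, currency):
        self.amount = amount
        self.currency = currency

    def __eq__(self, other):
        if not isinstance(other, Money):
            return NotImplemented
        return self.amount == other.amount and self.currency == other.currency

    def __lt__(self, other):
        if not isinstance(other, Money):
            return NotImplemented
        if self.currency != other.currency:
            raise ValueError("Cannot compare different currencies")
        return self.amount < other.amount


# vars() lists only attributes defined directly on the class, not inherited ones
added = sorted(name for name in ("__le__", "__gt__", "__ge__", "__ne__")
               if name in vars(Money))
print(added)  # ['__ge__', '__gt__', '__le__']

# != still works, via object.__ne__ inverting __eq__
print(Money(10, "USD") != Money(20, "USD"))  # True
```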
The Hashability Warning (Day 2 Callback)
Remember from Day 2: if your object is hashable (implements __hash__), you must be careful with equality.
Never allow this if your object is hashable:
m = Money(10, "USD")
print(m == 10)  # must evaluate to False, never True
Why? Because if Money(10, "USD") == 10 returns True, but hash(Money(10, "USD")) != hash(10), you violate the hashability contract:
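If a == b, then hash(a) must equal hash(b).
This is why our __eq__ returns NotImplemented for non-Money objects. When int's __eq__ declines as well, == falls back to an identity check, so m == 10 evaluates to False: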
def __eq__(self, other):
    if not isinstance(other, Money):
        return NotImplemented  # Let Python figure it out
    return self.amount == other.amount and self.currency == other.currency
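A quick sketch to check these rules. It shows that Money == 10 evaluates to False, that a class defining __eq__ without __hash__ has its __hash__ set to None (so instances are unhashable), and that a __hash__ built from the same fields __eq__ compares lets equal values collapse in a set. HashableMoney is a hypothetical subclass used only for illustration.

```python
class Money:
    def __init__(self, amount, currency):
        self.amount = amount
        self.currency = currency

    def __eq__(self, other):
        if not isinstance(other, Money):
            return NotImplemented  # both sides decline -> == falls back to identity
        return self.amount == other.amount and self.currency == other.currency


print(Money(10, "USD") == 10)  # False
print(Money.__hash__)          # None: defining __eq__ alone disables hashing

try:
    hash(Money(10, "USD"))
except TypeError as exc:
    print("TypeError:", exc)


class HashableMoney(Money):
    def __hash__(self):
        # Hash exactly the fields __eq__ compares, so equal objects hash equally
        return hash((self.amount, self.currency))


prices = {HashableMoney(10, "USD"), HashableMoney(10, "USD"), HashableMoney(10.0, "USD")}
print(len(prices))  # 1
```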
The Sequence Protocol: Free Features Through Convention
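Here's something magical. Implement just two methods, and your class supports len(), indexing, iteration, and membership testing. Slicing works as well, because our __getitem__ passes slice objects straight through to the underlying list.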
class Vector:
    def __init__(self, *components):
        self._components = list(components)

    def __len__(self):
        return len(self._components)

    def __getitem__(self, index):
        return self._components[index]

    def __repr__(self):
        return f"Vector{tuple(self._components)}"


v = Vector(1, 2, 3, 4, 5)
# Indexing - we implemented this explicitly
print(v[0])   # 1
print(v[-1])  # 5
# Slicing - handled by the underlying list, so we get a plain list back
print(v[1:3])  # [2, 3]
# Iteration - we get this for FREE
for component in v:
    print(component)  # 1, 2, 3, 4, 5 (one per line)
# Membership - we get this for FREE
print(3 in v)   # True
print(10 in v)  # False
# Length - we implemented this explicitly
print(len(v))  # 5
How is this possible? When you write for x in v:, Python first checks if v has an __iter__ method. If not, it falls back to the sequence protocol:
- Call v[0], assign it to x, and execute the loop body.
- Call v[1], assign it to x, and execute the loop body.
- Keep incrementing the index until v[n] raises IndexError.
- Catch the IndexError and stop the iteration.
Similarly, when you write 3 in v, Python:
- checks whether v has a __contains__ method;
- if not, iterates through v, checking item == 3 for each item.
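Both fallbacks need only __getitem__; __len__ matters for len() alone. A sketch with a hypothetical Squares class that defines nothing but __getitem__ and computes values on the fly (it only accepts non-negative integer indexes):

```python
class Squares:
    """Hypothetical read-only sequence of the first n squares.

    It defines only __getitem__: no __len__, __iter__ or __contains__.
    """

    def __init__(self, n):
        self.n = n

    def __getitem__(self, index):
        # Non-negative integer indexes only (no slices, no negatives)
        if not 0 <= index < self.n:
            raise IndexError(index)  # IndexError is what ends the iteration
        return index * index


sq = Squares(4)
print(list(sq))  # [0, 1, 4, 9]
print(9 in sq)   # True  (found while iterating)
print(5 in sq)   # False (iteration ran until IndexError)

try:
    len(sq)
except TypeError as exc:
    print("TypeError:", exc)  # len() still needs __len__
```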
Making It More Pythonic
We can add more magic methods to make our Vector even more powerful:
class Vector:
    def __init__(self, *components):
        self._components = list(components)

    def __len__(self):
        return len(self._components)

    def __getitem__(self, index):
        return self._components[index]

    def __setitem__(self, index, value):
        self._components[index] = value

    def __add__(self, other):
        if isinstance(other, Vector):
            if len(self) != len(other):
                raise ValueError("Vectors must have same length")
            return Vector(*(a + b for a, b in zip(self, other)))
        return NotImplemented

    def __mul__(self, scalar):
        if isinstance(scalar, (int, float)):
            return Vector(*(x * scalar for x in self))
        return NotImplemented

    def __rmul__(self, scalar):
        return self.__mul__(scalar)

    def __repr__(self):
        return f"Vector{tuple(self._components)}"


v1 = Vector(1, 2, 3)
v2 = Vector(4, 5, 6)
print(v1 + v2)  # Vector(5, 7, 9)
print(v1 * 2)   # Vector(2, 4, 6)
print(3 * v1)   # Vector(3, 6, 9)
# We can modify vectors
v1[0] = 10
print(v1)  # Vector(10, 2, 3)
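One possible refinement: if you'd rather have v[1:3] return a Vector than a plain list, __getitem__ can check whether it received a slice object and wrap the result. A sketch of that variant:

```python
class Vector:
    def __init__(self, *components):
        self._components = list(components)

    def __len__(self):
        return len(self._components)

    def __getitem__(self, index):
        if isinstance(index, slice):
            # Slice the list, then wrap the result back into a Vector
            return Vector(*self._components[index])
        return self._components[index]  # plain int index -> single component

    def __repr__(self):
        return f"Vector{tuple(self._components)}"


v = Vector(1, 2, 3, 4, 5)
print(v[1:3])  # Vector(2, 3)
print(v[::2])  # Vector(1, 3, 5)
print(v[-1])   # 5
```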
The Iterator Protocol
If you want more control over iteration, you can implement __iter__:
class Vector:
    def __init__(self, *components):
        self._components = list(components)

    def __len__(self):
        return len(self._components)

    def __getitem__(self, index):
        return self._components[index]

    def __iter__(self):
        # Return an iterator object
        return iter(self._components)

    def __repr__(self):
        return f"Vector{tuple(self._components)}"
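Now iteration calls __iter__ once and gets a list iterator back, instead of calling __getitem__ from Python code for every element and stopping on IndexError. That is usually faster for large sequences, and it gives you full control over what iteration yields.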
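__iter__ must return an iterator: an object whose __next__ returns successive items and raises StopIteration when done, and whose own __iter__ returns itself. Delegating to iter(list) gives you one for free; writing it by hand shows what the protocol asks for. VectorIterator is a hypothetical name used for illustration:

```python
class VectorIterator:
    """Walks a list of components; one instance per loop."""

    def __init__(self, components):
        self._components = components
        self._position = 0

    def __iter__(self):
        return self  # iterators are iterable themselves

    def __next__(self):
        if self._position >= len(self._components):
            raise StopIteration  # tells the for loop to stop
        value = self._components[self._position]
        self._position += 1
        return value


class Vector:
    def __init__(self, *components):
        self._components = list(components)

    def __iter__(self):
        # A fresh iterator each time, so nested or repeated loops work
        return VectorIterator(self._components)


v = Vector(1, 2, 3)
print(list(v))                             # [1, 2, 3]
print([(a, b) for a in v for b in v][:4])  # [(1, 1), (1, 2), (1, 3), (2, 1)]
```

A generator-based __iter__ (def __iter__(self): yield from self._components) is a shorter way to get the same behavior, since generators implement __next__ and __iter__ for you.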
The Missing Link: How defaultdict Actually Works
You've probably used collections.defaultdict:
from collections import defaultdict
counts = defaultdict(int)
counts['a'] += 1 # No KeyError!
counts['b'] += 1
print(counts) # defaultdict(<class 'int'>, {'a': 1, 'b': 1})
But how does it work? Through a special hook called __missing__.
The __missing__ Protocol
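The __missing__ hook works only in subclasses of dict (collections.UserDict honors it too); a plain dict never calls it. When a lookup through dict.__getitem__ fails, Python calls __missing__ instead of raising KeyError: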
class DefaultDict(dict):
    def __init__(self, default_factory):
        super().__init__()
        self.default_factory = default_factory

    def __missing__(self, key):
        # Called by dict.__getitem__ when key is absent, instead of raising KeyError
        if self.default_factory is None:
            raise KeyError(key)
        # Create the default value
        value = self.default_factory()
        # Store it in the dict
        self[key] = value
        # Return it
        return value


# Now we've built our own defaultdict!
counts = DefaultDict(int)
counts['a'] += 1  # __missing__ creates 0, then += 1 makes it 1
counts['b'] += 1
print(counts)  # {'a': 1, 'b': 1}
The Execution Flow
When you do counts['a'] += 1 and 'a' doesn't exist:
- Python calls counts.__getitem__('a').
- The inherited dict.__getitem__ looks for 'a' and doesn't find it.
- dict.__getitem__ calls counts.__missing__('a').
- __missing__ creates 0, stores it at counts['a'], and returns 0.
- Python increments the returned 0 to get 1.
- Python calls counts.__setitem__('a', 1).
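To confirm that sequence, here is a tracing sketch. TracingDict is a hypothetical dict subclass that logs each special-method call. Because __missing__ stores the default with self[key] = ..., that store goes through __setitem__ too, so __setitem__ shows up twice.

```python
class TracingDict(dict):
    def __init__(self):
        super().__init__()
        self.log = []  # instance attribute; dict subclasses have a __dict__

    def __getitem__(self, key):
        self.log.append(f"__getitem__({key!r})")
        return super().__getitem__(key)  # dict.__getitem__ calls __missing__ on a miss

    def __missing__(self, key):
        self.log.append(f"__missing__({key!r})")
        self[key] = 0  # goes through __setitem__ below
        return 0

    def __setitem__(self, key, value):
        self.log.append(f"__setitem__({key!r}, {value!r})")
        super().__setitem__(key, value)


counts = TracingDict()
counts['a'] += 1
for entry in counts.log:
    print(entry)
# __getitem__('a')
# __missing__('a')
# __setitem__('a', 0)
# __setitem__('a', 1)
```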
Critical detail: __missing__ is only called by __getitem__. It's not called by .get():
counts = DefaultDict(int)
print(counts['x']) # 0 - __missing__ called
print(counts.get('y')) # None - __missing__ NOT called
A Practical Use Case
Let's build a case-insensitive dictionary:
class CaseInsensitiveDict(dict):
    def __setitem__(self, key, value):
        # Always store with lowercase keys
        super().__setitem__(key.lower(), value)

    def __getitem__(self, key):
        # Always look up with lowercase keys
        return super().__getitem__(key.lower())

    def __missing__(self, key):
        # Provide a helpful error message (key arrives already lowercased)
        raise KeyError(f"Key '{key}' not found (case-insensitive)")


config = CaseInsensitiveDict()
config['Server'] = 'localhost'
print(config['SERVER'])  # localhost
print(config['server'])  # localhost
print(config['SeRvEr'])  # localhost
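One caveat with subclassing dict directly: dict's own methods don't route through your overrides. The constructor, .update(), .get() and the in operator all bypass the overridden __getitem__/__setitem__, so with the class above config.get('SERVER') returns None and 'SERVER' in config is False. A common alternative, swapped in here, is collections.UserDict: its inherited constructor and update() store items through __setitem__. A sketch under those assumptions (it drops the custom __missing__ and assumes string keys):

```python
from collections import UserDict


class CaseInsensitiveDict(UserDict):
    # UserDict keeps the real data in self.data, a plain dict

    def __setitem__(self, key, value):
        self.data[key.lower()] = value

    def __getitem__(self, key):
        return self.data[key.lower()]

    def __delitem__(self, key):
        del self.data[key.lower()]

    def __contains__(self, key):
        return isinstance(key, str) and key.lower() in self.data


config = CaseInsensitiveDict(Server='localhost')  # constructor uses __setitem__
config.update(PORT=8080)                          # so does update()
print(config['SERVER'])       # localhost
print('server' in config)     # True
print(config.get('port'))     # 8080
print(config.get('missing'))  # None
```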
Or a tree structure that auto-creates nested dictionaries:
class TreeDict(dict):
    def __missing__(self, key):
        # Create a new TreeDict for missing keys
        value = TreeDict()
        self[key] = value
        return value


tree = TreeDict()
tree['a']['b']['c'] = 42  # No KeyErrors!
print(tree)  # {'a': {'b': {'c': 42}}}
The Complete Protocol Checklist
Here's your reference for making objects feel native:
Arithmetic Protocol
- __add__, __radd__ → +
- __sub__, __rsub__ → -
- __mul__, __rmul__ → *
- __truediv__, __rtruediv__ → /
- __floordiv__, __rfloordiv__ → //
- __mod__, __rmod__ → %
- __pow__, __rpow__ → **
Always return NotImplemented for unsupported types, never raise NotImplementedError.
Comparison Protocol
- __eq__ → ==
- __ne__ → !=
- __lt__ → <
- __le__ → <=
- __gt__ → >
- __ge__ → >=
Use @total_ordering to derive <=, >, and >= from __eq__ and __lt__; != comes from __eq__ automatically.
Sequence Protocol
- __len__ → len(obj)
- __getitem__ → obj[key]
- __setitem__ → obj[key] = value
- __delitem__ → del obj[key]
- __contains__ → x in obj (optional: without it, in falls back to iterating via __iter__ or __getitem__)
- __iter__ → for x in obj (optional: without it, iteration falls back to __getitem__)
Dictionary Protocol
- Inherit from dict
- __missing__ → handles missing key lookups
Summary:
Today we've learned how to make custom objects integrate seamlessly with Python's syntax:
Key Principles
- Return NotImplemented, not NotImplementedError → Enables cooperative dispatch
- Implement reflected operations (__radd__, etc.) → Supports 5 + obj in addition to obj + 5
- Use @total_ordering → Write __eq__ and __lt__, get <=, >, and >= derived (!= comes from __eq__)
- Implement __len__ + __getitem__ → Get len(), indexing, iteration, and membership (plus slicing when you delegate to a list)
- Use __missing__ in dict subclasses → Intercept missing key lookups
The Professional Template
from functools import total_ordering


@total_ordering
class Money:
    __slots__ = ('_amount', '_currency')

    def __init__(self, amount, currency):
        # Bypass our own __setattr__, which blocks every later write
        object.__setattr__(self, '_amount', amount)
        object.__setattr__(self, '_currency', currency)

    # Immutability: a hashable object must not change after creation
    def __setattr__(self, name, value):
        raise AttributeError(f"{type(self).__name__} is immutable")

    def __delattr__(self, name):
        raise AttributeError(f"{type(self).__name__} is immutable")

    # Representation
    def __repr__(self):
        return f"Money({self.amount!r}, {self.currency!r})"

    # Properties
    @property
    def amount(self):
        return self._amount

    @property
    def currency(self):
        return self._currency

    # Arithmetic
    def __add__(self, other):
        if isinstance(other, Money):
            if self.currency != other.currency:
                raise ValueError("Cannot add different currencies")
            return Money(self.amount + other.amount, self.currency)
        if isinstance(other, (int, float)):
            return Money(self.amount + other, self.currency)
        return NotImplemented

    def __radd__(self, other):
        return self.__add__(other)

    def __sub__(self, other):
        if isinstance(other, Money):
            if self.currency != other.currency:
                raise ValueError("Cannot subtract different currencies")
            return Money(self.amount - other.amount, self.currency)
        if isinstance(other, (int, float)):
            return Money(self.amount - other, self.currency)
        return NotImplemented

    def __rsub__(self, other):
        # Subtraction is NOT commutative: other - self
        if isinstance(other, (int, float)):
            return Money(other - self.amount, self.currency)
        return NotImplemented

    def __mul__(self, other):
        if isinstance(other, (int, float)):
            return Money(self.amount * other, self.currency)
        return NotImplemented

    def __rmul__(self, other):
        return self.__mul__(other)

    # Comparison (total_ordering derives <=, >, >=; != comes from __eq__)
    def __eq__(self, other):
        if not isinstance(other, Money):
            return NotImplemented
        return self.amount == other.amount and self.currency == other.currency

    def __lt__(self, other):
        if not isinstance(other, Money):
            return NotImplemented
        if self.currency != other.currency:
            raise ValueError("Cannot compare different currencies")
        return self.amount < other.amount

    # Hashability: hash exactly the fields __eq__ compares
    def __hash__(self):
        return hash((self.amount, self.currency))