William Lewis

Sum Types in Python

Python is a lovely language. However, when working in Python I frequently find myself missing built-in support for sum types. Languages like Haskell and Rust make this kind of thing so easy:

data Op = Add | Sub | Mul
  deriving (Show)

data Expr
  = Lit Integer
  | BinOp Op Expr Expr
  deriving (Show)

val :: Expr -> Integer
val (Lit n) = n
val (BinOp op lhs rhs) =
  let x = val lhs
      y = val rhs
   in apply op x y

apply :: Op -> Integer -> Integer -> Integer
apply Add x y = x + y
apply Sub x y = x - y
apply Mul x y = x * y

val (BinOp Add (BinOp Mul (Lit 2) (Lit 3)) (Lit 4))
-- => 10

While Python doesn't support this kind of construction out-of-the-box, we'll see that types like Expr are nonetheless possible (and easy) to express. Furthermore, we can create a decorator that handles all of the nasty boilerplate for us. The result isn't too different from the Haskell example above:

# The `enum` decorator adds methods for constructing and matching on the
# different variants:
@enum(add=(), sub=(), mul=())
class Op:
    def apply(self, x, y):
        return self.match(
            add=lambda: x + y,
            sub=lambda: x - y,
            mul=lambda: x * y,
        )


# Recursive sum types are also supported:
@enum(lit=(int,), bin_op=lambda: (Op, Expr, Expr))
class Expr:
    def val(self):
        return self.match(
            lit=lambda value: value,
            bin_op=lambda op, lhs, rhs: op.apply(lhs.val(), rhs.val()),
        )


Expr.bin_op(
    Op.add(),
    Expr.bin_op(Op.mul(), Expr.lit(2), Expr.lit(3)),
    Expr.lit(4)
).val()
# => 10

Representing Sum Types

We'll represent sum types using a "tagged union". This is easy to grok by example:

class Expr:
    def lit(value):
        e = Expr()
        e.tag = "lit"
        e.value = value
        return e

    def bin_op(op, lhs, rhs):
        e = Expr()
        e.tag = "bin_op"
        e.op = op
        e.lhs = lhs
        e.rhs = rhs
        return e

Each variant is an instance of the same class (in this case Expr). Each one contains a "tag" indicating which variant it is, along with the data specific to it.
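
To make the representation concrete, here is a small sketch that builds one value of each variant and inspects the tag and fields. The string "add" is only a placeholder for an operator (an assumption for illustration); the class itself is the one shown above.

```python
class Expr:
    def lit(value):
        e = Expr()
        e.tag = "lit"
        e.value = value
        return e

    def bin_op(op, lhs, rhs):
        e = Expr()
        e.tag = "bin_op"
        e.op = op
        e.lhs = lhs
        e.rhs = rhs
        return e


two = Expr.lit(2)
# "add" is a placeholder operator, used only for this illustration.
total = Expr.bin_op("add", two, Expr.lit(3))

print(two.tag, two.value)          # lit 2
print(total.tag, total.lhs.value)  # bin_op 2
print(isinstance(total, Expr))     # True
```

Both values are plain Expr instances; only the tag tells them apart.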

The most basic way to use an Expr is with an if-else chain:

class Expr:
    # ...
    def val(self):
        if self.tag == "lit":
            return self.value
        elif self.tag == "bin_op":
            x = self.lhs.val()
            y = self.rhs.val()
            return self.op.apply(x, y)

However, this has a few downsides:

  • The same if-else chain is repeated everywhere an Expr is used.
  • Changing the tag's value—say from "lit" to "literal"—breaks existing code.
  • Consuming sum types requires knowing implementation details (i.e. the tag and the names of the fields used by each variant).

Implementing match

We can avoid all of these issues by exposing a single, public match method used to consume sum types:

class Expr:
    # ...
    def match(self, **handlers):
        # ...

But first we need to make the different variants a little more uniform. Instead of storing its data in various fields, each variant will now store it in a tuple named data:

class Expr:
    def lit(value):
        e = Expr()
        e.tag = "lit"
        e.data = (value,)
        return e

    def bin_op(op, lhs, rhs):
        e = Expr()
        e.tag = "bin_op"
        e.data = (op, lhs, rhs)
        return e

This allows us to implement match:

class Expr:
    # ...
    def match(self, **handlers):
        if self.tag in handlers:
            return handlers[self.tag](*self.data)
        else:
            raise RuntimeError(f"missing handler for {self.tag}")
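
As a quick check, here is a sketch of val rewritten on top of match. To keep the example self-contained it uses operator.add and operator.mul from the standard library as stand-ins for an Op type; that substitution is an assumption made for brevity.

```python
import operator


class Expr:
    def lit(value):
        e = Expr()
        e.tag = "lit"
        e.data = (value,)
        return e

    def bin_op(op, lhs, rhs):
        e = Expr()
        e.tag = "bin_op"
        e.data = (op, lhs, rhs)
        return e

    def match(self, **handlers):
        if self.tag in handlers:
            return handlers[self.tag](*self.data)
        else:
            raise RuntimeError(f"missing handler for {self.tag}")

    def val(self):
        # No tag checks and no field names: one handler per variant.
        return self.match(
            lit=lambda value: value,
            bin_op=lambda op, lhs, rhs: op(lhs.val(), rhs.val()),
        )


# (2 * 3) + 4, with plain functions standing in for Op values.
expr = Expr.bin_op(
    operator.add,
    Expr.bin_op(operator.mul, Expr.lit(2), Expr.lit(3)),
    Expr.lit(4),
)
print(expr.val())  # 10
```

Forgetting a handler fails loudly: Expr.lit(1).match(bin_op=...) raises RuntimeError rather than silently returning None.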

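In one fell swoop we've addressed the problems noted above: callers no longer repeat if-else chains or read the tag and data fields directly; they only name the variants they want to handle. As another example, and for a change of scenery, here's Rust's Option type transcribed in this fashion: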

class Option:
    def some(x):
        o = Option()
        o.tag = "some"
        o.data = (x,)
        return o

    def none():
        o = Option()
        o.tag = "none"
        o.data = ()
        return o

    def match(self, **handlers):
        if self.tag in handlers:
            return handlers[self.tag](*self.data)
        else:
            raise RuntimeError(f"missing handler for {self.tag}")

    def __repr__(self):
        return self.match(
            some=lambda x: f"Option.some({repr(x)})",
            none=lambda: "Option.none()",
        )

    def __eq__(self, other):
        if not isinstance(other, Option):
            return NotImplemented
        return self.tag == other.tag and self.data == other.data

    def map(self, fn):
        return self.match(
            some=lambda x: Option.some(fn(x)),
            none=lambda: Option.none()
        )

Option.some(2).map(lambda x: x**2)
# => Option.some(4)

As a small quality of life benefit, we can support a special wildcard or "catchall" handler in match, indicated by an underscore (_):

def match(self, **handlers):
    if self.tag in handlers:
        return handlers[self.tag](*self.data)
    elif "_" in handlers:
        return handlers["_"]()
    else:
        raise RuntimeError(f"missing handler for {self.tag}")

This allows us to use match like:

def map(self, fn):
    return self.match(
        some=lambda x: Option.some(fn(x)),
        _=lambda: Option.none(),
    )
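
The wildcard is most useful for helpers that care about only one variant. Here is a sketch with two hypothetical helpers, is_some and unwrap_or (neither appears in the original class; they are here purely for illustration):

```python
class Option:
    def some(x):
        o = Option()
        o.tag = "some"
        o.data = (x,)
        return o

    def none():
        o = Option()
        o.tag = "none"
        o.data = ()
        return o

    def match(self, **handlers):
        if self.tag in handlers:
            return handlers[self.tag](*self.data)
        elif "_" in handlers:
            # The wildcard handler is always called with no arguments.
            return handlers["_"]()
        else:
            raise RuntimeError(f"missing handler for {self.tag}")

    # Hypothetical helpers, for illustration only:
    def is_some(self):
        return self.match(some=lambda _x: True, _=lambda: False)

    def unwrap_or(self, default):
        return self.match(some=lambda x: x, _=lambda: default)


print(Option.some(3).is_some())     # True
print(Option.none().is_some())      # False
print(Option.some(3).unwrap_or(0))  # 3
print(Option.none().unwrap_or(0))   # 0
```

Note that the wildcard handler receives no data, so it can't inspect the value it matched.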

Implementing enum

As the Option class illustrates, a lot of the code needed to create sum types follows the same pattern:

class Foo:
    # For each variant:
    def my_variant(bar, quux):
        # Construct an instance of the class:
        f = Foo()
        # Give the instance a distinct tag:
        f.tag = "my_variant"
        # Save the values we received:
        f.data = (bar, quux)
        return f

    # This is always the same:
    def match(self, **handlers):
        if self.tag in handlers:
            return handlers[self.tag](*self.data)
        elif "_" in handlers:
            return handlers["_"]()
        else:
            raise RuntimeError(f"missing handler for {self.tag}")

Instead of writing this ourselves, let's write a decorator to generate these methods based on some description of the variants.

def enum(**variants):
    pass

What kind of a description? The simplest thing would be to supply a list of variant names, but we can do a little better by also providing the types of arguments that we expect. We'd use enum to automagically enhance our Option class like this:

# Add two variants:
# - One named `some` that expects a single argument of any type.
# - One named `none` that expects no arguments.
@enum(some=(object,), none=())
class Option:
    pass

The basic structure of enum looks like this:

def enum(**variants):
    def enhance(cls):
        # Add methods to the class cls.
        return cls

    return enhance

It's a function that returns another function, which will be called with the class we're enhancing as its only argument. Within enhance we'll attach methods for constructing each variant, along with match.
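
If decorators that take arguments are unfamiliar, it may help to see the two calls spelled out. In this sketch enhance only records the variant names in a hypothetical variants attribute, a placeholder for the real work described below:

```python
def enum(**variants):
    def enhance(cls):
        # Placeholder for the real work: just record the variant names.
        cls.variants = tuple(variants)
        return cls

    return enhance


@enum(some=(object,), none=())
class Option:
    pass


# The decorator syntax above is shorthand for this explicit form:
class Option2:
    pass


Option2 = enum(some=(object,), none=())(Option2)

print(Option.variants)   # ('some', 'none')
print(Option2.variants)  # ('some', 'none')
```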

First, match, because it's just copy pasta:

def enhance(cls):
    def match(self, **handlers):
        if self.tag in handlers:
            return handlers[self.tag](*self.data)
        elif "_" in handlers:
            return handlers["_"]()
        else:
            raise RuntimeError(f"missing handler for {self.tag}")

    # Add a method named "match" to the class cls, whose value is the
    # `match` function defined above:
    setattr(cls, "match", match)

    return cls

Adding methods to construct each variant is only slightly more involved. We iterate over the variants dictionary, defining a method for each entry:

def enhance(cls):
    # ...
    for tag, sig in variants.items():
        setattr(cls, tag, make_constructor(tag, sig))

    return cls

where make_constructor creates a constructor function for a variant with tag (and name) tag, and "type signature" sig:

def enhance(cls):
    # ...
    def make_constructor(tag, sig):
        def constructor(*data):
            # Validate the data passed to the constructor:
            if len(sig) != len(data):
                raise ValueError(f"expected {len(sig)} items, not {len(data)}")
            for x, ty in zip(data, sig):
                if not isinstance(x, ty):
                    raise TypeError(f"expected {ty} but got {repr(x)}")

            # Just a generalization of what we've seen above:
            inst = cls()
            inst.tag = tag
            inst.data = data
            return inst

        return constructor

    for tag, sig in variants.items():
        setattr(cls, tag, make_constructor(tag, sig))

    return cls

Here's the full definition of enum for reference.
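
Assembled from the fragments above, a minimal sketch of the whole decorator might look like this. It contains only match and the variant constructors introduced so far, raises RuntimeError for a missing handler as the hand-written versions do, and is not necessarily identical to the author's own full version.

```python
def enum(**variants):
    def enhance(cls):
        def match(self, **handlers):
            if self.tag in handlers:
                return handlers[self.tag](*self.data)
            elif "_" in handlers:
                return handlers["_"]()
            else:
                raise RuntimeError(f"missing handler for {self.tag}")

        setattr(cls, "match", match)

        def make_constructor(tag, sig):
            def constructor(*data):
                # Validate the number and types of the arguments.
                if len(sig) != len(data):
                    raise ValueError(f"expected {len(sig)} items, not {len(data)}")
                for x, ty in zip(data, sig):
                    if not isinstance(x, ty):
                        raise TypeError(f"expected {ty} but got {repr(x)}")

                inst = cls()
                inst.tag = tag
                inst.data = data
                return inst

            return constructor

        for tag, sig in variants.items():
            setattr(cls, tag, make_constructor(tag, sig))

        return cls

    return enhance


@enum(some=(object,), none=())
class Option:
    pass


print(Option.some(2).match(some=lambda x: x + 1, none=lambda: 0))  # 3
print(Option.none().match(some=lambda x: x + 1, none=lambda: 0))   # 0

try:
    Option.some()  # wrong number of arguments
except ValueError as err:
    print(err)  # expected 1 items, not 0
```

Because the constructors call cls() with no arguments, this sketch assumes the decorated class can be instantiated without any.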

Bonus Features

More Dunder Methods

We can easily enhance our sum classes with __repr__ and __eq__ methods:

def enhance(cls):
    # ...
    def _repr(self):
        return f"{cls.__name__}.{self.tag}({', '.join(map(repr, self.data))})"

    setattr(cls, "__repr__", _repr)

    def _eq(self, other):
        if not isinstance(other, cls):
            return NotImplemented
        return self.tag == other.tag and self.data == other.data

    setattr(cls, "__eq__", _eq)

    return cls

With enhance improved in this fashion, we can define Option with minimal cruft:

@enum(some=(object,), none=())
class Option:
    def map(self, fn):
        return self.match(
            some=lambda x: Option.some(fn(x)),
            _=lambda: Option.none(),
        )
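
Here is a sketch of the generated dunder methods in action. The argument validation is left out to keep it short. The _hash method is an addition not in the original: attaching __eq__ with setattr after the class exists leaves the inherited identity-based __hash__ in place (defining __eq__ inside the class body would disable it instead), so equal values would otherwise hash differently.

```python
def enum(**variants):
    def enhance(cls):
        def match(self, **handlers):
            if self.tag in handlers:
                return handlers[self.tag](*self.data)
            elif "_" in handlers:
                return handlers["_"]()
            else:
                raise RuntimeError(f"missing handler for {self.tag}")

        def make_constructor(tag, sig):
            def constructor(*data):
                # Validation omitted here for brevity.
                inst = cls()
                inst.tag = tag
                inst.data = data
                return inst

            return constructor

        def _repr(self):
            return f"{cls.__name__}.{self.tag}({', '.join(map(repr, self.data))})"

        def _eq(self, other):
            if not isinstance(other, cls):
                return NotImplemented
            return self.tag == other.tag and self.data == other.data

        def _hash(self):
            # Assumption, not in the original: keep hashing consistent with
            # __eq__. This requires the stored data to be hashable.
            return hash((self.tag, self.data))

        setattr(cls, "match", match)
        setattr(cls, "__repr__", _repr)
        setattr(cls, "__eq__", _eq)
        setattr(cls, "__hash__", _hash)
        for tag, sig in variants.items():
            setattr(cls, tag, make_constructor(tag, sig))
        return cls

    return enhance


@enum(some=(object,), none=())
class Option:
    def map(self, fn):
        return self.match(
            some=lambda x: Option.some(fn(x)),
            _=lambda: Option.none(),
        )


print(Option.some(2).map(lambda x: x**2))     # Option.some(4)
print(Option.none().map(lambda x: x**2))      # Option.none()
print(Option.some(2) == Option.some(2))       # True
print(Option.some(2) == Option.none())        # False
print(len({Option.some(2), Option.some(2)}))  # 1
```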

Recursive Definitions

Unfortunately, enum isn't (yet) up to the task of defining Expr:

@enum(add=(), sub=(), mul=())
class Op:
    pass

@enum(lit=(int,), bin_op=(Op, Expr, Expr))
class Expr:
    pass

# NameError: name 'Expr' is not defined

We're using the class Expr before it's been defined. An easy fix here is to simply call the decorator after defining the class:

class Expr:
    pass

enum(lit=(int,), bin_op=(Op, Expr, Expr))(Expr)

But there's a simple change we can make to keep the decorator syntax: allow a "signature" to be a function that returns a tuple, so that the name Expr is only looked up when a constructor is first called:

@enum(lit=(int,), bin_op=lambda: (Op, Expr, Expr))
class Expr:
    pass

All this requires is a small change in make_constructor:

def make_constructor(tag, sig):
    def constructor(*data):
        nonlocal sig
        # If sig is a "thunk", thaw it out:
        if callable(sig):
            sig = sig()

        # ...
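
Putting the thunk handling together with the earlier pieces, the sketch below shows a decorator that can define Expr recursively. It is assembled from the fragments in this post rather than copied from the author's full version. The nonlocal declaration means the thawed tuple replaces the thunk, so the lambda runs at most once per variant.

```python
def enum(**variants):
    def enhance(cls):
        def match(self, **handlers):
            if self.tag in handlers:
                return handlers[self.tag](*self.data)
            elif "_" in handlers:
                return handlers["_"]()
            else:
                raise RuntimeError(f"missing handler for {self.tag}")

        def make_constructor(tag, sig):
            def constructor(*data):
                nonlocal sig
                # If sig is a thunk, evaluate it now (the class names it
                # refers to exist by the time a constructor is called).
                if callable(sig):
                    sig = sig()

                if len(sig) != len(data):
                    raise ValueError(f"expected {len(sig)} items, not {len(data)}")
                for x, ty in zip(data, sig):
                    if not isinstance(x, ty):
                        raise TypeError(f"expected {ty} but got {repr(x)}")

                inst = cls()
                inst.tag = tag
                inst.data = data
                return inst

            return constructor

        setattr(cls, "match", match)
        for tag, sig in variants.items():
            setattr(cls, tag, make_constructor(tag, sig))
        return cls

    return enhance


@enum(add=(), sub=(), mul=())
class Op:
    def apply(self, x, y):
        return self.match(
            add=lambda: x + y,
            sub=lambda: x - y,
            mul=lambda: x * y,
        )


@enum(lit=(int,), bin_op=lambda: (Op, Expr, Expr))
class Expr:
    def val(self):
        return self.match(
            lit=lambda value: value,
            bin_op=lambda op, lhs, rhs: op.apply(lhs.val(), rhs.val()),
        )


expr = Expr.bin_op(
    Op.add(),
    Expr.bin_op(Op.mul(), Expr.lit(2), Expr.lit(3)),
    Expr.lit(4),
)
print(expr.val())  # 10
```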

Conclusion

Useful as it may be, our fancy new enum decorator isn't without its shortcomings. The most apparent is the inability to perform any kind of "nested" pattern matching. In Rust, we can do things like this:

use std::fmt::Debug;

fn foo<T: Debug>(x: Option<Option<T>>) {
    match x {
        Some(Some(value)) => println!("{:?}", value),
        _ => {}
    }
}

But we're forced to perform a double match to achieve the same result:

def foo(x):
    return x.match(
        some=lambda x1: x1.match(
            some=lambda value: print(value),
            _=lambda: None
        ),
        _=lambda: None
    )

That said, these kinds of cases seem relatively rare.
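
When nesting does come up, one possible way to flatten it is a chaining helper. The and_then method below is hypothetical (it is not part of the decorator described in this post, and the name is borrowed from Rust); with it, foo needs only a single match:

```python
class Option:
    def some(x):
        o = Option()
        o.tag = "some"
        o.data = (x,)
        return o

    def none():
        o = Option()
        o.tag = "none"
        o.data = ()
        return o

    def match(self, **handlers):
        if self.tag in handlers:
            return handlers[self.tag](*self.data)
        elif "_" in handlers:
            return handlers["_"]()
        else:
            raise RuntimeError(f"missing handler for {self.tag}")

    # Hypothetical helper: apply fn (which must itself return an Option)
    # to the wrapped value, or propagate none.
    def and_then(self, fn):
        return self.match(some=lambda x: fn(x), _=lambda: Option.none())


def foo(x):
    # Collapse Option<Option<T>> into Option<T>, then match once.
    return x.and_then(lambda inner: inner).match(
        some=lambda value: print(value),
        _=lambda: None,
    )


foo(Option.some(Option.some(3)))  # prints 3
foo(Option.some(Option.none()))   # prints nothing
foo(Option.none())                # prints nothing
```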

Another downside is that match requires constructing and calling lots of functions. This means it's likely much slower than the equivalent if-else chain. However, the usual rule of thumb applies here: use enum if you like its ergonomic benefits, and replace it with its "generated" code if it's too slow.
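
To see whether the overhead matters for your workload, a rough measurement with the standard timeit module is enough. The sketch below compares a match-based map against a hand-written if-else version; the function names map_with_match and map_with_if are made up for this example, and the timings will vary by machine and Python version, so no numbers are claimed here.

```python
import timeit


class Option:
    def some(x):
        o = Option()
        o.tag = "some"
        o.data = (x,)
        return o

    def none():
        o = Option()
        o.tag = "none"
        o.data = ()
        return o

    def match(self, **handlers):
        if self.tag in handlers:
            return handlers[self.tag](*self.data)
        elif "_" in handlers:
            return handlers["_"]()
        else:
            raise RuntimeError(f"missing handler for {self.tag}")


def map_with_match(opt, fn):
    # Builds two lambdas and a keyword dict on every call.
    return opt.match(
        some=lambda x: Option.some(fn(x)),
        _=lambda: Option.none(),
    )


def map_with_if(opt, fn):
    # The "generated" form: branch on the tag directly.
    if opt.tag == "some":
        return Option.some(fn(opt.data[0]))
    else:
        return Option.none()


def square(x):
    return x * x


opt = Option.some(2)
t_match = timeit.timeit(lambda: map_with_match(opt, square), number=100_000)
t_if = timeit.timeit(lambda: map_with_if(opt, square), number=100_000)
print(f"match: {t_match:.3f}s  if-else: {t_if:.3f}s")
```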
