If you're totally unfamiliar with type annotations in Python, my previous article should get you started.
In this post, I'm going to show you how to use type variables, or `TypeVar`s, for fun and profit.
Consider this function:

```python
def identity(arg):
    return arg
```

It accepts anything as the argument and returns it as is. How do you explain to the type checker that the return type is the same as the type of `arg`?
You could annotate it with `Any`:

```python
from typing import Any

def identity(arg: Any) -> Any:
    return arg
```
If you use `Any`, the type checker will not understand how this function works: as far as it's concerned, the function can return anything at all. The return type doesn't depend on the type of `arg`.
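To make the problem concrete, here is a small sketch. `shout` is a hypothetical helper written only for this illustration. Because `identity` returns `Any`, a checker such as mypy (with its default settings) treats `value` as `Any` and lets the bogus `.upper()` call through, so the mistake only shows up when the code runs.

```python
from typing import Any


def identity(arg: Any) -> Any:
    return arg


def shout(n: int) -> str:
    # Hypothetical helper, for illustration only.
    value = identity(n)  # statically, `value` is Any, not int
    # int has no .upper(), but since `value` is Any the type checker
    # accepts this line; the error surfaces only at runtime.
    return value.upper()


try:
    shout(42)
except AttributeError as exc:
    print(f"Caught at runtime, not by the type checker: {exc}")
```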
You could also write a separate function for every type:

```python
def identity_int(arg: int) -> int:
    return arg

def identity_str(arg: str) -> str:
    return arg

def identity_list_str(arg: list[str]) -> list[str]:
    return arg

...
```
This doesn't scale well. Are you going to replicate the same function 10 times? Will you remember to keep them in sync?
What if this is a library function? You won't be able to predict all the ways people will use this function.
Type variables allow you to link several types together. This is how you can use a type variable to annotate the `identity` function:
```python
from typing import TypeVar

T = TypeVar("T")

def identity(arg: T) -> T:
    return arg
```
Here the return type is "linked" to the parameter type: whatever you put into the function, the same thing comes out.
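A small sketch of what that linking buys you. The comments note what a type checker such as mypy or Pyright infers at each call; at runtime, `identity` simply hands back its argument.

```python
from typing import TypeVar

T = TypeVar("T")


def identity(arg: T) -> T:
    return arg


# T is solved separately at each call site:
number = identity(41) + 1          # T is int, so `number` is an int
greeting = identity("hi").upper()  # T is str, so .upper() type-checks
items = identity([1, 2, 3])        # T is list[int]

# Misuse is now caught statically: the checker knows identity(41) is an int,
# so a call like identity(41).upper() would be reported as an error.
```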
This is how it looks in action (in VSCode with Pylance):
Is this a well-typed function?
```python
from typing import Union

def triple(string: Union[str, bytes]) -> Union[str, bytes]:
    return string * 3
```
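It runs fine, but the annotation throws information away: whatever you pass in, the checker only knows that the result is a `str` or a `bytes`, not which one. A short sketch of the consequence, assuming a checker such as mypy:

```python
from typing import Union


def triple(string: Union[str, bytes]) -> Union[str, bytes]:
    return string * 3


result = triple("A")  # at runtime this is a plain str
# Statically, `result` is only known to be Union[str, bytes]. Because bytes + str
# is an error, a checker such as mypy reports the next line, even though it runs.
unicode_scream = result + "!"
```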
"If you pass in `str`, you get `str`; if you pass in `bytes`, you get `bytes`." That sounds like a job for a type variable.
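However, if you annotate `triple` with a plain, unconstrained `T`, the type checker rejects `string * 3`, because `T` could stand for any type at all. That's fair enough: not all types support multiplication. We can put a restriction that our type variable should only accept `str` and `bytes` (and their subclasses, of course).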
```python
from typing import TypeVar

AnyString = TypeVar("AnyString", str, bytes)

def triple(string: AnyString) -> AnyString:
    return string * 3

unicode_scream = triple("A") + "!"
bytes_scream = triple(b"A") + b"!"
```
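Two details of the constrained version, sketched below. The constraints are stored on the `TypeVar` object itself (in `__constraints__`), and a subclass argument is accepted but solved to the matching constraint: with `Name`, a hypothetical `str` subclass made up for this example, a checker such as mypy infers `str`, not `Name`, for the result. That happens to match runtime behavior, since `str.__mul__` returns a plain `str`.

```python
from typing import TypeVar

AnyString = TypeVar("AnyString", str, bytes)


def triple(string: AnyString) -> AnyString:
    return string * 3


class Name(str):
    """Hypothetical str subclass, used only for this illustration."""


print(AnyString.__constraints__)  # the tuple of allowed types, str and bytes

shouted = triple(Name("A"))  # accepted; AnyString is solved to str, not Name
print(type(shouted).__name__)  # str.__mul__ returns a plain str at runtime

# A call like triple(3) would be rejected by the type checker,
# because int is neither str nor bytes.
```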
You can also use type variables as parameters to generic types, like `list` or `Iterable`:
```python
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")

def remove_falsey_from_list(items: list[T]) -> list[T]:
    return [item for item in items if item]

def remove_falsey(items: Iterable[T]) -> Iterator[T]:
    for item in items:
        if item:
            yield item
```
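A quick usage sketch for these two functions. The element type flows through unchanged because `T` is solved from the argument; the comments note what a checker such as mypy infers. The more general `Iterable[T]` parameter also accepts tuples and sets, while `list[T]` accepts only lists.

```python
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def remove_falsey_from_list(items: list[T]) -> list[T]:
    return [item for item in items if item]


def remove_falsey(items: Iterable[T]) -> Iterator[T]:
    for item in items:
        if item:
            yield item


numbers = remove_falsey_from_list([0, 1, 2, 0, 3])  # list[int]
words = list(remove_falsey(("", "type", "", "vars")))  # T is str, so list[str]
evens = sorted(remove_falsey({0, 2, 4}))  # a set is fine for Iterable[T]

# remove_falsey_from_list({0, 2, 4}) would be rejected by the type checker:
# a set is not a list.
```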
However, this gets tricky pretty fast. I'll cover it in depth in the next article.
mypy documentation on generic functions: https://mypy.readthedocs.io/en/stable/generics.html#generic-functions