Some time ago one of my Java colleagues from work asked:

How do you do a flat map in Python?

Well, there's no built-in `flat_map` function, but it can be implemented in a few different ways. Instead of just discussing how one can write a pythonic flat map, I want to focus on a more interesting question:

What is the most efficient way to do a flat map in Python?
Let's start with a definition. A flat map is an operation that takes a list whose elements have type `a` and a function `f` of type `a -> [b]`. The function `f` is applied to each element of the initial list, and then all the results are concatenated. So the type of `flat_map` is:

`flat_map :: (a -> [b]) -> [a] -> [b]`
I think showing an example is much simpler than describing it:

```python
id = lambda x: x
flat_map(id, [[1, 2], [3, 4]]) == [1, 2, 3, 4]

s = lambda x: x.split(",")
flat_map(s, ["a,b", "c,d"]) == ["a", "b", "c", "d"]
```
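To make the definition concrete, here is a direct, unoptimized transcription of it into Python (a sketch; the type variables `A` and `B` and the two-step structure are my own choices, not one of the implementations benchmarked below). It applies `f` to every element first and then concatenates the partial results, reproducing both examples:

```python
from typing import Callable, Iterable, List, TypeVar

A = TypeVar("A")
B = TypeVar("B")

def flat_map(f: Callable[[A], Iterable[B]], xs: Iterable[A]) -> List[B]:
    # step 1: apply f to every element, giving a list of partial results
    mapped = [f(x) for x in xs]
    # step 2: concatenate all partial results into one flat list
    result: List[B] = []
    for part in mapped:
        result.extend(part)
    return result

print(flat_map(lambda x: x, [[1, 2], [3, 4]]))          # [1, 2, 3, 4]
print(flat_map(lambda x: x.split(","), ["a,b", "c,d"]))  # ['a', 'b', 'c', 'd']
```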
Although it's a really simple operation, the general idea behind it is a marvelous one: flat map is the bind operation of the list monad. As far as I know, there are four simple ways to implement `flat_map` in Python:
- for loop and extend
- double list comprehension
- map and reduce
- map and sum
Which approach is the most efficient? To answer this question, I am going to measure the execution time of each implementation in three cases:
a) 100 lists each with 10 integers
b) 10 000 lists each with 10 integers
c) 10 000 lists each with 10 objects (class instances)
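To do this, I will use this simple `check_time` function, with the identity function as the function applied by the flat map:

```python
import timeit

def id(x):
    # identity: flat_map applies it to every element of the input list
    return x

def check_time(f, arg):
    n = 50  # number of repetitions
    # average time of a single call, converted from seconds to milliseconds
    return 1000 * (timeit.timeit(lambda: f(id, arg), number=n) / n)
```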
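The article doesn't show how the three benchmark inputs were built. A minimal sketch, assuming `range`-based integer lists and a hypothetical empty class `Item` for case c), could look like this (it repeats `id` and `check_time` so it runs on its own, and uses the loop + `extend` version as a stand-in for any implementation):

```python
import timeit

def id(x):
    return x

def check_time(f, arg):
    n = 50
    return 1000 * (timeit.timeit(lambda: f(id, arg), number=n) / n)  # ms

class Item:
    """Hypothetical placeholder class for the 'class instances' case."""

# a) 100 lists each with 10 integers
case_a = [list(range(10)) for _ in range(100)]
# b) 10 000 lists each with 10 integers
case_b = [list(range(10)) for _ in range(10_000)]
# c) 10 000 lists each with 10 objects (class instances)
case_c = [[Item() for _ in range(10)] for _ in range(10_000)]

def flat_map(f, xs):
    # any of the implementations discussed in this article can be plugged in here
    ys = []
    for x in xs:
        ys.extend(f(x))
    return ys

for name, case in [("a", case_a), ("b", case_b), ("c", case_c)]:
    print(name, f"{check_time(flat_map, case):.4f} ms")
```

The absolute numbers will of course depend on the machine and the Python version.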
All the code I used for this article is available as a gist.
If there's a "most functional" answer, then it's the combination of `map` and `reduce` (also known as fold). It requires an import from `functools` (if you don't know this module or the `reduce` function, check them out!), so it's not a real one-liner.
```python
from functools import reduce

# concatenate the results pairwise: ((f(x1) + f(x2)) + f(x3)) + ...
flat_map = lambda f, xs: reduce(lambda a, b: a + b, map(f, xs))
```
a) 0.1230 ms
b) 1202.4405 ms
c) 3249.0119 ms
The map operation can be substituted with a list comprehension but that will result in no spectacular improvement.
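Why is this so slow? Each `a + b` allocates a brand-new list and copies both operands into it, so the growing prefix is copied again at every step. A small sketch (the helpers `count_copies` and `closed_form` are hypothetical, written only for illustration) counts the copied elements for `n` sublists of length `k`; the total is `k * (n * (n + 1) / 2 - 1)`, i.e. quadratic in `n`:

```python
from functools import reduce

def count_copies(n, k):
    """Count how many elements reduce(lambda a, b: a + b, ...) copies
    when flattening n sublists of length k."""
    copies = 0

    def concat(a, b):
        nonlocal copies
        result = a + b         # new list: every element of a and b is copied
        copies += len(result)
        return result

    reduce(concat, [[0] * k for _ in range(n)])
    return copies

def closed_form(n, k):
    # the step that joins i sublists builds a list of i * k elements (i = 2..n)
    return k * (n * (n + 1) // 2 - 1)

print(count_copies(100, 10), closed_form(100, 10))       # 50490 50490
print(closed_form(10_000, 10) // closed_form(100, 10))   # 9903
```

So 100x more sublists means roughly 10 000x more copying, which matches the jump from case a) to case b) in the measurements.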
Another functional-like solution is to use `sum`. I really like this approach; however, its performance degrades with the length of the input list, and when applied to lists of class instances it's awfully slow. The results are no better than the previous ones.
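```python
# the start value [] is required: sum's default start, 0, cannot be added to a list
flat_map = lambda f, xs: sum(map(f, xs), [])
```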
a) 0.1080 ms
b) 1150.4400 ms
c) 3296.3167 ms
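The similar numbers are no coincidence: `sum` starts from its start value and adds the items one by one with `+`, so `sum(map(f, xs), [])` performs the same chain of list concatenations as a left fold. A small sketch checking this equivalence (the names `flat_map_sum` and `flat_map_fold` are hypothetical), which also shows why the `[]` start value is needed:

```python
import operator
from functools import reduce

def flat_map_sum(f, xs):
    # start=[] because sum's default start value is the integer 0
    return sum(map(f, xs), [])

def flat_map_fold(f, xs):
    # explicit left fold with + and the same initial value
    return reduce(operator.add, map(f, xs), [])

data = [[1, 2], [3], [4, 5, 6]]
print(flat_map_sum(lambda x: x, data) == flat_map_fold(lambda x: x, data))  # True

try:
    sum([[1, 2], [3]])  # default start 0, so this evaluates 0 + [1, 2]
except TypeError as err:
    print("TypeError:", err)
```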
So far, our `flat_map` implementations were quite good when applied to relatively small inputs and lousy on 10 000-element lists. Luckily, this "classic" implementation with a for loop and `extend` is up to 1000x faster! Its only drawback is the number of lines it requires, but that's definitely a good price for truly improved performance.
```python
def flat_map(f, xs):
    ys = []
    for x in xs:
        # extend grows ys in place instead of building a new list each time
        ys.extend(f(x))
    return ys
```
a) 0.0211 ms
b) 2.4535 ms
c) 2.7347 ms
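For contrast, a sketch counting the elements appended in the loop + `extend` approach (`count_extend_copies` is a hypothetical helper). Since `ys` is mutated in place, every element is appended to the result exactly once; occasional internal reallocations of the list are not counted here, as they only add amortized constant cost per element. The work therefore grows linearly: 100 000 appends for case b), compared with roughly 500 million element copies that repeated `a + b` performs for the same input.

```python
def count_extend_copies(n, k):
    """Count elements appended to the result when flattening n sublists
    of length k with a for loop and list.extend."""
    copies = 0
    ys = []
    for x in [[0] * k for _ in range(n)]:
        ys.extend(x)  # appends in place; ys is never rebuilt from scratch
        copies += len(x)
    return copies

print(count_extend_copies(100, 10))     # 1000
print(count_extend_copies(10_000, 10))  # 100000
```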
Interestingly, the difference between applying this `flat_map` to lists of integers and to lists of class instances is no longer so distinctively big.
The double list comprehension is a truly pythonic approach (and also a functional one). Its downside is readability: I always have to check the proper order of the `for` clauses in a double comprehension. Of course, it's just a nit.
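A handy way to remember the order: the `for` clauses of a comprehension appear in the same order as the equivalent nested `for` loops, outermost first. A small sketch showing both forms side by side:

```python
xs = [[1, 2], [3, 4]]

# nested loops: outer loop over the sublists, inner loop over their items
flat_loops = []
for ys in xs:
    for y in ys:
        flat_loops.append(y)

# the comprehension lists the same for clauses in the same order
flat_comp = [y for ys in xs for y in ys]

print(flat_loops == flat_comp)  # True
print(flat_comp)                # [1, 2, 3, 4]
```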
```python
flat_map = lambda f, xs: [y for ys in xs for y in f(ys)]
```
a) 0.0372 ms
b) 4.1477 ms
c) 4.5945 ms
As we can see, it performs really well! It's slightly slower than the previous implementation, but a list comprehension is a one-liner! Someone may even suggest switching the list comprehension to a generator expression, like this:
```python
flat_map = lambda f, xs: (y for ys in xs for y in f(ys))
```
a) 0.0008 ms
b) 0.0008 ms
c) 0.0007 ms
Even better, eh? Of course, using a generator expression is only applicable when we are going to iterate over the result, and a generator does no work until it is consumed. So, to verify its real performance, we should measure code similar to this:
```python
ys = flat_map(id, xs)
for y in ys:
    pass
```
Then we get the following results:
a) 0.0537 ms
b) 5.3700 ms
c) 5.7781 ms
a) 0.1512 ms
b) 13.0335 ms
c) 13.3203 ms
So the initial improvement was not a real one: in fact, nothing was executed. This shows that when measuring performance, we should always keep in mind the context in which the code is used.
|Method|100 x 10 int (ms)|10 000 x 10 int (ms)|10 000 x 10 class (ms)|
|---|---|---|---|
|map + reduce|0.1230|1202.4405|3249.0119|
|map + sum|0.1080|1150.4400|3296.3167|
|for + extend|0.0211|2.4535|2.7347|
|double list comprehension|0.0372|4.1477|4.5945|
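A benchmark table is only meaningful if all entries compute the same result. A quick sanity check (a sketch; the dictionary simply collects the four versions from this article under hypothetical names) could look like this:

```python
from functools import reduce

def flat_map_loop(f, xs):
    ys = []
    for x in xs:
        ys.extend(f(x))
    return ys

implementations = {
    "map + reduce": lambda f, xs: reduce(lambda a, b: a + b, map(f, xs)),
    "map + sum": lambda f, xs: sum(map(f, xs), []),
    "for + extend": flat_map_loop,
    "double list comprehension": lambda f, xs: [y for ys in xs for y in f(ys)],
}

data = [list(range(i, i + 10)) for i in range(0, 1000, 10)]  # 100 x 10 ints
expected = list(range(1000))

for name, flat_map in implementations.items():
    assert flat_map(lambda x: x, data) == expected, name
print("all implementations agree")
```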
Simple problem, many solutions, but only one winner... or two of them. I think it's reasonable to award first place to both loop + extend and the list comprehension. The first one is great when performance is your main concern; the second one is a really good choice when you need a handy one-liner.
So the next time you need a quick `flat_map` in Python, just use a list comprehension. And most importantly, do not be tempted to use `sum` for relatively long lists!