DEV Community

Tomasz Urbaszek


Flat map in Python 🐍

Some time ago one of my Java colleagues from work asked:

How do you do a flat map in Python?

Well, there's no built-in flat_map function, but it can be implemented in a few different ways. Instead of just discussing how one can write a pythonic flat map, I want to focus on a more interesting question:

What is the most efficient way to do a flat map in Python?


Let's start with a definition. A flat map is an operation that takes a list whose elements have type t and a function f of type t -> [a]. The function f is applied to each element of the initial list, and all the results are then concatenated. So the type of flat_map is:

flat_map :: (t -> [a]) -> [t] -> [a]

I think showing an example is much simpler than describing it:

id = lambda x: x
flat_map(id, [[1,2], [3,4]]) == [1, 2, 3, 4]

s = lambda x: x.split(",")
flat_map(s, ["a,b", "c,d"]) == ["a", "b", "c", "d"]

It's a really simple operation, although the general idea behind it is a marvelous one: flat map is the bind operation of a monad (here, the list monad). As far as I know, there are four simple ways to implement flat_map in Python:

  • for loop and extend
  • double list comprehension
  • map and reduce
  • map and sum


Which approach is the most efficient? To answer this question I am going to measure the time of execution of each implementation in 3 cases:

a) 100 lists each with 10 integers
b) 10 000 lists each with 10 integers
c) 10 000 lists each with 10 objects (class instances)

To do this I will use this simple check_time function, with the identity function as the function applied by the flat map:

import timeit

def id(x):
    return x

def check_time(f, arg):
    n = 50
    return 1000 * (timeit.timeit(lambda: f(id, arg), number=n) / n)  # ms

All the code I used for this article is available as a gist.
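The gist is not reproduced here, so the following is only a rough sketch of how the three inputs and a single measurement might be set up. The `Item` class, the `make_input` helper, and the placeholder `flat_map` are assumptions made for illustration, not the author's actual code.

```python
import timeit


class Item:
    """Hypothetical stand-in for the 'class instances' of case c)."""

    def __init__(self, value):
        self.value = value


def make_input(n_lists, make_element, size=10):
    """Hypothetical helper: build n_lists inner lists of `size` elements."""
    return [[make_element(i) for i in range(size)] for _ in range(n_lists)]


def id(x):  # identity function, as in the article (shadows the builtin id)
    return x


def check_time(f, arg, n=50):
    """Average time of one f(id, arg) call, in milliseconds."""
    return 1000 * (timeit.timeit(lambda: f(id, arg), number=n) / n)


case_a = make_input(100, int)      # a) 100 lists of 10 integers
case_b = make_input(10_000, int)   # b) 10 000 lists of 10 integers
case_c = make_input(10_000, Item)  # c) 10 000 lists of 10 objects


def flat_map(f, xs):
    # Placeholder so the sketch runs on its own; swap in any implementation.
    return [y for ys in xs for y in f(ys)]


print(f"a) {check_time(flat_map, case_a):.4f} ms")
```

Absolute numbers will differ between machines and Python versions, so only the relative differences between implementations are meaningful.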

Map reduce

If there's a "most functional" answer, then it's the combination of map and reduce (also known as fold). It requires an import from functools (if you don't know this module or the reduce function, check them out!), so it's not a real one-liner.

from functools import reduce

flat_map = lambda f, xs: reduce(lambda a, b: a + b, map(f, xs))

a) 0.1230 ms
b) 1202.4405 ms
c) 3249.0119 ms

The map operation can be substituted with a list comprehension but that will result in no spectacular improvement.
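For instance, the substitution could look like the sketch below. The `[]` initial value is an addition made here (it is not in the version above) so that an empty input returns an empty list instead of raising a TypeError.

```python
from functools import reduce


def flat_map(f, xs):
    # List comprehension in place of map. The [] initial value is an
    # addition in this sketch so that empty input yields [] rather than
    # raising TypeError.
    return reduce(lambda a, b: a + b, [f(x) for x in xs], [])
```

The repeated `a + b` concatenation is unchanged, which is why no big speed-up should be expected from this variant.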

Map sum

Another functional-like solution is to use sum. I really like this approach; however, its performance degrades quadratically with the length of the input list, because every + builds a new list and copies everything accumulated so far. And when applied to lists of class instances rather than integers it's awfully slow. The results are no better than the previous ones.

flat_map = lambda f, xs: sum(map(f, xs), [])

a) 0.1080 ms
b) 1150.4400 ms
c) 3296.3167 ms
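Both reduce and sum slow down on long inputs for the same reason: each `a + b` allocates a new list and copies both operands, so the accumulated prefix is copied again at every step. The sketch below is a simplified counting model (hypothetical helper names, not a measurement of CPython internals, and it ignores the occasional reallocation done by extend).

```python
def copies_with_concat(n_lists, size=10):
    """Elements copied when folding n_lists lists of `size` items with +."""
    copied = 0
    acc_len = 0
    for _ in range(n_lists):
        # acc + xs builds a new list holding len(acc) + len(xs) elements
        acc_len += size
        copied += acc_len
    return copied


def copies_with_extend(n_lists, size=10):
    """Elements copied when each inner list is appended once with extend."""
    return n_lists * size


for n in (100, 10_000):
    print(n, copies_with_concat(n), copies_with_extend(n))
```

In this model, going from 100 to 10 000 lists multiplies the work of the `+` version by roughly 10 000, while the extend version grows only by a factor of 100.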

For and extend

So far our flat_map implementations were quite good on relatively small inputs and lousy on lists of length 10 000. Luckily, this 'classic' implementation is up to about 1000x faster on the large inputs! The only drawback of this solution is the number of lines it requires, but that's definitely a fair price for such an improvement in performance.

def flat_map(f, xs):
    ys = []
    for x in xs:
        # append all elements produced by f(x) to the result in place
        ys.extend(f(x))
    return ys

a) 0.0211 ms
b) 2.4535 ms
c) 2.7347 ms

Interestingly, the difference between applying this flat_map to lists of ints and lists of class instances is no longer so big.

List comprehension

This is a truly pythonic approach (and also a functional one). The downside of this implementation is its readability: I always have to check the proper order of the for clauses in a double comprehension. Of course, it's just a nit.

flat_map = lambda f, xs: [y for ys in xs for y in f(ys)]

a) 0.0372 ms
b) 4.1477 ms
c) 4.5945 ms
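On the order of the for clauses: they read left to right in the same order as the equivalent nested loops, outer loop first. A small sketch to illustrate (the function names are made up for this comparison):

```python
def flat_map_comprehension(f, xs):
    return [y for x in xs for y in f(x)]


def flat_map_loops(f, xs):
    result = []
    for x in xs:          # first `for` clause = outer loop
        for y in f(x):    # second `for` clause = inner loop
            result.append(y)
    return result


data = [[1, 2], [3, 4]]
print(flat_map_comprehension(lambda x: x, data) == flat_map_loops(lambda x: x, data))
```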

Judging by the timings, it performs really well! It is slightly slower than the previous implementation, but the list comprehension is a one-liner. Someone may even suggest switching the list comprehension to a generator expression, like this:

flat_map = lambda f, xs: (y for ys in xs for y in f(ys))

a) 0.0008 ms
b) 0.0008 ms
c) 0.0007 ms

Even better, eh? Of course, a generator expression only makes sense when we are going to iterate over the result. So, to verify its real performance we should measure code similar to this:

ys = flat_map(id, xs)
for y in ys:
    pass  # consume every element so the generator actually runs

Then we get the following results:
a) 0.0537 ms
b) 5.3700 ms
c) 5.7781 ms

a) 0.1512 ms
b) 13.0335 ms
c) 13.3203 ms

So the initial improvement was not a real one, because nothing was actually executed: creating the generator does no work until it is iterated. This shows that when measuring performance we should always keep in mind the context in which the code is used.
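The laziness is easy to observe directly. In this sketch, `traced` is a hypothetical helper that records each call, so we can see that f is not called when the generator is created, only when it is consumed.

```python
calls = []


def traced(x):
    """Identity function that records every call it receives."""
    calls.append(x)
    return x


flat_map = lambda f, xs: (y for ys in xs for y in f(ys))

gen = flat_map(traced, [[1, 2], [3, 4]])
calls_before = len(calls)  # f has not been called yet

result = list(gen)         # iterating forces the calls to f
calls_after = len(calls)

print(calls_before, calls_after, result)
```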


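| Method                    | 100 x 10 int | 10000 x 10 int | 10000 x 10 class |
|---------------------------|--------------|----------------|------------------|
| map reduce                | 0.1230 ms    | 1202.4405 ms   | 3249.0119 ms     |
| map sum                   | 0.1080 ms    | 1150.4400 ms   | 3296.3167 ms     |
| for + extend              | 0.0211 ms    | 2.4535 ms      | 2.7347 ms        |
| double list comprehension | 0.0372 ms    | 4.1477 ms      | 4.5945 ms        |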

A simple problem, many solutions, but only one winner... or two of them. I think it's reasonable to grant first place to both loop + extend and the list comprehension. The first one is great when performance is your main concern. The second one is a really good choice when you need a handy one-liner.

So the next time you need a quick flat_map in Python, just use a list comprehension. And most importantly, do not be tempted to use reduce or sum for lists that are relatively long!

Top comments (2)

David Crespo

Nice post. I think the table is mixed up. It shows map sum as the fastest.

Tomasz Urbaszek

Thanks David! I've updated the table 🚀