# Flat map in Python 🐍

Tomasz Urbaszek ・ 4 min read

Some time ago, one of my Java colleagues from work asked:

> How do you do a flat map in Python?

Well, there is no built-in `flat_map` function, but it can be implemented in a few different ways. Instead of just discussing how to write a Pythonic flat map, though, I want to focus on a more interesting question:

What is the most efficient way to do a flat map in Python?

## Definition

Let's start with a definition. A flat map is an operation that takes a list whose elements have type `A` and a function `f` of type `A -> [B]`. The function `f` is applied to each element of the initial list, and then all the results are concatenated. So the type of `flat_map` is:

```haskell
flat_map :: (a -> [b]) -> [a] -> [b]
```

I think showing an example is much simpler than describing it:

```python
id = lambda x: x  # identity function (shadows the built-in id)
flat_map(id, [[1, 2], [3, 4]]) == [1, 2, 3, 4]

s = lambda x: x.split(",")  # split a string on commas
flat_map(s, ["a,b", "c,d"]) == ["a", "b", "c", "d"]
```
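
To make the examples above executable, here is a minimal sketch of `flat_map` with type hints that mirror the type given above. The type variables `A` and `B` and the helper `identity` are illustrative names, and the body simply uses the loop + `extend` approach measured later in this post.

```python
from typing import Callable, Iterable, List, TypeVar

A = TypeVar("A")
B = TypeVar("B")


def flat_map(f: Callable[[A], Iterable[B]], xs: Iterable[A]) -> List[B]:
    """Apply f to every element of xs and concatenate the results."""
    ys: List[B] = []
    for x in xs:
        ys.extend(f(x))  # append every element produced by f(x)
    return ys


def identity(x):
    return x


print(flat_map(identity, [[1, 2], [3, 4]]))              # [1, 2, 3, 4]
print(flat_map(lambda s: s.split(","), ["a,b", "c,d"]))  # ['a', 'b', 'c', 'd']
```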

It's a really simple operation, although the general idea behind it is a marvelous one: flat map is the bind operation (`>>=`) of the list monad. As far as I know, there are four simple ways to implement `flat_map` in Python:

- `for` loop and `extend`
- double list comprehension
- `map` and `reduce`
- `map` and `sum`

## Measure

Which approach is the most efficient? To answer this question, I am going to measure the execution time of each implementation in three cases:

- a) 100 lists, each with 10 integers
- b) 10 000 lists, each with 10 integers
- c) 10 000 lists, each with 10 objects (class instances)
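
The post does not show how these inputs were built. The following is one possible way to construct them, assuming plain `range` integers for a) and b) and a hypothetical `Item` class for c).

```python
class Item:
    """Hypothetical class, used only to populate case c)."""

    def __init__(self, value):
        self.value = value


def make_inputs():
    # a) 100 lists, each with 10 integers
    small_ints = [list(range(10)) for _ in range(100)]
    # b) 10 000 lists, each with 10 integers
    many_ints = [list(range(10)) for _ in range(10_000)]
    # c) 10 000 lists, each with 10 class instances
    many_objects = [[Item(i) for i in range(10)] for _ in range(10_000)]
    return small_ints, many_ints, many_objects
```

With inputs like these, timing case b) would amount to a call such as `check_time(flat_map, many_ints)`.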

To do this, I will use the following simple `check_time` function, with the identity function as the function applied by the flat map:

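```python
import timeit


def id(x):
    # identity function applied by every flat_map (shadows the built-in id)
    return x


def check_time(f, arg):
    n = 50  # number of runs to average over
    # total time of n calls divided by n, converted from seconds to milliseconds
    return 1000 * (timeit.timeit(lambda: f(id, arg), number=n) / n)
```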
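
`check_time` reports the mean of 50 runs, which occasional background activity can inflate. A common alternative, sketched here under the assumption that you want a less noisy number, is to repeat the measurement with `timeit.repeat` and keep the fastest round. `check_time_best` is a hypothetical name, not part of the benchmark used for the results below.

```python
import timeit


def identity(x):
    return x


def check_time_best(f, arg, number=50, repeat=5):
    """Best per-call time of f(identity, arg), in milliseconds."""
    # timeit.repeat performs `repeat` rounds of `number` calls and returns each round's total
    totals = timeit.repeat(lambda: f(identity, arg), number=number, repeat=repeat)
    # the fastest round is the one least disturbed by other processes
    return 1000 * min(totals) / number
```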

## Map reduce

If there is a "most functional" answer, it is the combination of `map` and `reduce` (also known as fold). `reduce` has to be imported from `functools` (if you don't know this module or the function, check it out!), so it is not a true one-liner.

```python
from functools import reduce

# note: raises TypeError for an empty xs, because reduce gets no initial value
flat_map = lambda f, xs: reduce(lambda a, b: a + b, map(f, xs))
```

Result:

- a) 0.1230 ms
- b) 1202.4405 ms
- c) 3249.0119 ms

The `map` call can be replaced with a list comprehension, but that brings no spectacular improvement.
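
Why is `reduce` with `+` so slow on long inputs? Each `a + b` builds a brand-new list and copies both operands into it, so the growing accumulator is copied again at every step. The sketch below counts element copies under that simplified model (it ignores allocation and treats every copy as equal cost) for `n` inner lists of `k` elements each; the function names are hypothetical.

```python
def copies_with_concatenation(n, k):
    """Element copies made by reduce(lambda a, b: a + b, ...) over n lists of k items."""
    total = 0
    length = k  # reduce starts with the first inner list as the accumulator
    for _ in range(n - 1):
        length += k
        total += length  # a + b copies every element of both operands into a new list
    return total


def copies_with_extend(n, k):
    """Element copies made by a for + extend loop: each element is copied once."""
    return n * k


for n in (100, 10_000):
    print(n, copies_with_concatenation(n, 10), copies_with_extend(n, 10))
# 100 50490 1000
# 10000 500049990 100000
```

Under this model, concatenation grows quadratically with `n` while `extend` grows linearly. The model does not predict exact timings, but it explains why the gap widens so sharply between cases a) and b).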

## Map sum

Another functional-style solution is to use `sum`. I really like this approach; however, its performance degrades quickly with the length of the input list, because every `+` creates a new list. `sum` is really meant for numbers, and applied to lists it is awfully slow. The results are no better than those of `map` and `reduce`.

```python
flat_map = lambda f, xs: sum(map(f, xs), [])
```
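
The second argument of `sum` is its start value, and it is required here: by default `sum` starts from the integer `0`, so summing lists without it would evaluate `0 + [1, 2]` and raise a `TypeError`. A small check, using the identity as `f`:

```python
def f(x):
    return x  # identity, as in the benchmark


xs = [[1, 2], [3, 4]]

print(sum(map(f, xs), []))  # [1, 2, 3, 4]

try:
    sum(map(f, xs))  # start defaults to 0, so the first step is 0 + [1, 2]
except TypeError as exc:
    print("TypeError:", exc)
```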

Result:

- a) 0.1080 ms
- b) 1150.4400 ms
- c) 3296.3167 ms

## For and extend

So far, our `flat_map` implementations were quite good on relatively small inputs and lousy on the 10 000-list inputs. Luckily, this 'classic' implementation is roughly 470 to 1200 times faster on those inputs! The only drawback of this solution is the number of lines it requires, but that is definitely a fair price for such an improvement in performance.

```python
def flat_map(f, xs):
    ys = []
    for x in xs:
        ys.extend(f(x))  # append the elements of f(x) to ys in place
    return ys
```

Result:

- a) 0.0211 ms
- b) 2.4535 ms
- c) 2.7347 ms

Interestingly, the difference between applying this `flat_map` to `int` lists and to class-instance lists is no longer large (2.4535 ms vs 2.7347 ms, compared with 1202.4405 ms vs 3249.0119 ms for `map` and `reduce`).
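
The standard library also offers `itertools.chain.from_iterable`, which concatenates iterables lazily. It is not one of the four approaches measured in this post, so no timings are claimed for it; this is only a sketch of how it could express `flat_map`.

```python
from itertools import chain


def flat_map(f, xs):
    # map(f, xs) lazily yields f(x) for every x;
    # chain.from_iterable walks through those results one after another
    return list(chain.from_iterable(map(f, xs)))


print(flat_map(lambda s: s.split(","), ["a,b", "c,d"]))  # ['a', 'b', 'c', 'd']
```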

## List comprehension

This is a truly Pythonic approach (and a functional one, too). Its drawback is readability: I always have to check the proper order of the `for` clauses in a double comprehension. Of course, it's just a nit.
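
The order of the `for` clauses is easier to remember once you notice that they appear in the same order as the equivalent nested loops. A small sketch comparing the two forms, with placeholder `f` and `xs`:

```python
def f(ys):
    return ys  # identity, as in the benchmark


xs = [[1, 2], [3, 4]]

# double comprehension: the outer loop comes first, the inner loop second
flat = [y for ys in xs for y in f(ys)]

# the same computation as nested loops, written in the same order
nested = []
for ys in xs:          # first `for` clause
    for y in f(ys):    # second `for` clause
        nested.append(y)

print(flat == nested, flat)  # True [1, 2, 3, 4]
```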

```python
flat_map = lambda f, xs: [y for ys in xs for y in f(ys)]
```

Result:

- a) 0.0372 ms
- b) 4.1477 ms
- c) 4.5945 ms

As we can see, it performs really well: slightly slower than for + extend, but it is a one-liner! Someone may even suggest switching the list comprehension to a generator expression, like this:

```python
flat_map = lambda f, xs: (y for ys in xs for y in f(ys))
```

Result:

- a) 0.0008 ms
- b) 0.0008 ms
- c) 0.0007 ms

Even better, eh? Of course, a generator expression is only applicable when we are going to iterate over the result. So, to measure its real performance, we should time code similar to this:

```python
ys = flat_map(id, xs)
for y in ys:  # consume every element of the result
    pass
```
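
One way to turn this into a timing helper, assuming the same `timeit` approach as `check_time`, is to time the call together with a full pass over its result. `consume` and `check_time_consumed` are hypothetical names used only for this sketch.

```python
import timeit


def identity(x):
    return x


def consume(f, arg):
    # iterate over the whole result so that a lazy generator does its work too
    for _ in f(identity, arg):
        pass


def check_time_consumed(f, arg, n=50):
    """Mean time in milliseconds of calling f(identity, arg) and iterating over the result."""
    return 1000 * (timeit.timeit(lambda: consume(f, arg), number=n) / n)


list_flat_map = lambda f, xs: [y for ys in xs for y in f(ys)]
gen_flat_map = lambda f, xs: (y for ys in xs for y in f(ys))
```

For example, `check_time_consumed(gen_flat_map, xs)` times the generator version including the iteration, so it can be compared fairly with `check_time_consumed(list_flat_map, xs)`.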

Then we get the following results.

Comprehension:

- a) 0.0537 ms
- b) 5.3700 ms
- c) 5.7781 ms

Generator:

- a) 0.1512 ms
- b) 13.0335 ms
- c) 13.3203 ms

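So the initial improvement was illusory: creating a generator executes nothing. Once the result is actually consumed, the generator turns out to be between two and three times slower than the list comprehension. This shows that when measuring performance, we should always keep in mind the context in which the code is used.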
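
The laziness is easy to observe directly. In this sketch, a hypothetical `counting_identity` records how many times it is called: building the generator calls it zero times, and only iterating over the result triggers the calls.

```python
calls = 0


def counting_identity(x):
    """Identity function that counts its calls (hypothetical helper)."""
    global calls
    calls += 1
    return x


gen_flat_map = lambda f, xs: (y for ys in xs for y in f(ys))

xs = [[1, 2], [3, 4], [5, 6]]
ys = gen_flat_map(counting_identity, xs)
print(calls)  # 0: creating the generator does not run the loop body
result = list(ys)
print(calls, result)  # 3 [1, 2, 3, 4, 5, 6]
```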

## Summary

| Method | a) 100 × 10 `int` (ms) | b) 10 000 × 10 `int` (ms) | c) 10 000 × 10 class (ms) |
|---|---:|---:|---:|
| map + reduce | 0.1230 | 1202.4405 | 3249.0119 |
| map + sum | 0.1080 | 1150.4400 | 3296.3167 |
| for + extend | 0.0211 | 2.4535 | 2.7347 |
| double list comprehension | 0.0372 | 4.1477 | 4.5945 |

Simple problem, many solutions, but only one winner... or two of them. I think it is reasonable to award first place to both loop + extend and the list comprehension. The first is great when performance is your main concern; the second is a really good choice when you need a handy one-liner.

So the next time you need a quick `flat_map` in Python, just use a list comprehension. And most importantly, do not be tempted to use `reduce` or `sum` on relatively long lists!

## Discussion

David Crespo

Nice post. I think the table is mixed up. It shows map sum as the fastest.

Tomasz Urbaszek

Thanks David! I've updated the table 🚀