DEV Community

Andrew (he/him)


Flatten a Ragged List N Levels Deep in Scala

The Problem

In Scala, it's easy to flatten a nested List of Lists with the flatten method:

scala> val nestedList = List(List(1), List(2), List(3, 4))
nestedList: List[List[Int]] = List(List(1), List(2), List(3, 4))

scala> nestedList.flatten
res28: List[Int] = List(1, 2, 3, 4)
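Note that flatten removes exactly one level of nesting per call, so a uniformly nested list needs one call per level. A small sketch (the names twoDeep, oneLevel and fullyFlat are only for illustration):

```scala
// A list nested two levels deep, with the same depth everywhere.
val twoDeep: List[List[List[Int]]] = List(List(List(1)), List(List(2, 3)))

// One call to flatten removes one level: List(List(1), List(2, 3))
val oneLevel: List[List[Int]] = twoDeep.flatten

// A second call removes the next level: List(1, 2, 3)
val fullyFlat: List[Int] = twoDeep.flatten.flatten
```

Chaining works here only because every element sits at the same depth and the types line up at each step.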

...but what if your List has more than one level of nesting, or uneven levels? What if the first element is not a List, but the rest of the elements are?

scala> val letterList = List("a", List("b"), List("c"), List("d", "e"))
letterList: List[java.io.Serializable] = List(a, List(b), List(c), List(d, e))

scala> letterList.flatten
<console>:13: error: No implicit view available from java.io.Serializable => scala.collection.GenTraversableOnce[B].
       letterList.flatten
                  ^

Notice the types of the first and second Lists in the interpreter:

nestedList: List[List[Int]] // vs
letterList: List[java.io.Serializable]

To infer the type of a List, Scala looks for the narrowest type that encompasses every element (their least upper bound). In the first example, this is easy: all of our inner Lists contain only Int elements, so they must all be List[Int]s. This is more difficult in the second example, because the elements of the outermost list can be Strings as well as List[String]s. The narrowest type that encompasses both of those is java.io.Serializable. Using Ints again, we get what might be a more understandable result:

scala> val raggedList = List(1, List(2), List(3), List(4, 5))
raggedList: List[Any] = List(1, List(2), List(3), List(4, 5))

scala> raggedList.flatten
<console>:13: error: No implicit view available from Any => scala.collection.GenTraversableOnce[B].
       raggedList.flatten
                  ^

...this is a List[Any], because Any is the narrowest type that covers both Int and List[Int] values. This means that when we try to flatten the List, we're calling a method with the signature (with the element type filled in as Any):

def flatten[B](implicit asTraversable: Any => scala.collection.GenTraversableOnce[B]): List[B]

This signature might be a little intimidating, but the important bit is the implicit parameter: flatten needs a way to view each element of the List (here, of type Any) as a GenTraversableOnce[B] for some element type B, and it returns a List of that type, List[B]. For nestedList, every element is a List[Int], which already is a GenTraversableOnce[Int], so B is simply Int. In our raggedList example, there is no implicit conversion from Any to GenTraversableOnce, so flatten can't be applied.
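Since flatten only needs a function that views each element as a collection, one workaround is to supply that function yourself instead of relying on an implicit. A minimal sketch, assuming a hypothetical helper asList that passes inner Lists through unchanged and wraps every other value in a one-element List; it is passed explicitly to flatten's implicit parameter list:

```scala
// The ragged list from above: its inferred element type is Any.
val raggedList: List[Any] = List(1, List(2), List(3), List(4, 5))

// Hypothetical helper: views any element as a List.
// An inner List is returned as-is; any other value is wrapped in a one-element List.
val asList: Any => List[Any] = {
  case inner: List[_] => inner
  case other          => List(other)
}

// Pass the conversion explicitly to flatten; B is inferred as Any.
val flattened: List[Any] = raggedList.flatten(asList)

// Only one level is removed, so deeper nesting survives.
val deeper: List[Any] = List(1, List(2, List(3))).flatten(asList)
```

Here flattened is List(1, 2, 3, 4, 5), but deeper is still List(1, 2, List(3)): supplying the conversion fixes the type error without solving the problem of arbitrary nesting depth.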

The Solution

Instead of blindly using flatten, we can think about the structure of our data. We have a List whose elements are each either

  1. Lists
  2. anything else that's not a List

...and we want to "unwrap" the inner Lists into their constituent elements. How can we do that? Pattern matching!

def squash(list: List[Any]): List[Any] = list match {
  // An empty list has nothing to flatten
  case Nil => Nil
  case x :: xs => x match {
    // x is itself a List (possibly empty or nested): flatten it completely, then the rest
    case inner: List[_] => squash(inner) ::: squash(xs)
    // x is any other value: keep it, then flatten the rest
    case _ => x :: squash(xs)
  }
}

The squash method above first looks at the outer List, named list, and breaks it into a head x and a tail xs. If x is itself a List, we recursively squash it until it's completely flat, then concatenate the result with the rest of the squashed outer list, squash(xs). Because the match only tests whether x is a List[_], an empty inner List contributes no elements, and a List nested directly at the head of another List (as in List(List(List(1)))) is flattened as well.

If the head of list, x, isn't a List, we simply prepend it to the squashed outer list, squash(xs).
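The same two cases can also be written with flatMap, which turns every element into a (possibly empty) List and concatenates the results. This is a sketch of an alternative formulation, not the author's version. One practical difference: the pattern-matching squash makes one nested call per element, which can exhaust the stack on very long lists, whereas here recursion only happens when descending into an inner List, so call depth follows the nesting depth instead of the list length.

```scala
// Each element becomes a List; flatMap iterates over the outer list and concatenates them.
def squash(list: List[Any]): List[Any] = list.flatMap {
  case inner: List[_] => squash(inner) // an inner List: flatten it recursively
  case leaf           => List(leaf)    // anything else: keep it as a single element
}

val crazyList: List[Any] = List(1, 2, List(3), 4, List(5, List(6, 7)), List(8, List(9, List(0))))
val squashed: List[Any] = squash(crazyList)
```

Here squashed is List(1, 2, 3, 4, 5, 6, 7, 8, 9, 0), matching the REPL output below.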

Does it work? It does indeed!

scala> squash(raggedList)
res34: List[Any] = List(1, 2, 3, 4, 5)

scala> squash(letterList)
res35: List[Any] = List(a, b, c, d, e)

scala> val crazyList = List(1, 2, List(3), 4, List(5, List(6, 7)), List(8, List(9, List(0))))
crazyList: List[Any] = List(1, 2, List(3), 4, List(5, List(6, 7)), List(8, List(9, List(0))))

scala> squash(crazyList)
res36: List[Any] = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 0)

...with one caveat, of course: every resulting List is a List[Any]! Can you improve this solution to fix that?
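One possible direction, sketched here under assumptions rather than as the author's answer: if you know which leaf type you expect, a scala.reflect.ClassTag lets a type pattern check each leaf's runtime class despite erasure. The helper name squashAs is hypothetical, and it returns an Option so that a leaf of the wrong type yields None instead of being silently dropped.

```scala
import scala.reflect.ClassTag

// A fully flattening squash, restated so this block runs on its own.
def squash(list: List[Any]): List[Any] = list match {
  case Nil                      => Nil
  case (inner: List[_]) :: rest => squash(inner) ::: squash(rest)
  case leaf :: rest             => leaf :: squash(rest)
}

// Hypothetical helper: flattens, then checks at runtime that every leaf is an A.
// Returns None if any leaf has some other type.
def squashAs[A: ClassTag](list: List[Any]): Option[List[A]] = {
  val leaves = squash(list)
  // The implicit ClassTag[A] makes `case a: A` a real runtime check.
  val typed = leaves.collect { case a: A => a }
  if (typed.length == leaves.length) Some(typed) else None
}

val ints: Option[List[Int]] = squashAs[Int](List(1, List(2, List(3)), List(4, 5)))
val letters: Option[List[String]] = squashAs[String](List("a", List("b"), List("c"), List("d", "e")))
val mismatch: Option[List[Int]] = squashAs[Int](List(1, List("b")))
```

Here ints is Some(List(1, 2, 3, 4, 5)), letters is Some(List(a, b, c, d, e)), and mismatch is None. Because a ClassTag only checks the outer class, this suits simple leaf types such as Int or String; squashAs[Option[Int]] would accept any Option.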
