Stefan Compton

Recursion in Scala

Recursion

"It's déjà vu all over again." - Yogi Berra

I was in an interview coding test recently, trying to solve a problem in under 5 minutes. The problem was simple enough, some kind of String manipulation, but required a loop to solve it. Since I was working in Scala, and since I believe the best way to work in Scala is the pure functional way (as much as possible), I wrote the loop as a tail recursive function.

Coding Interviews

As an aside, the format of this interview leaves a lot to be desired in my opinion.

Firstly, it is entirely unrealistic in terms of the work I was being assessed to do. When will I ever have 5 minutes to solve a problem? And could anyone trust something thrown together in such haste?
Secondly, the string manipulation in question could be more effectively done with libraries that I was not allowed to use.
Finally, when I asked for some test cases, or info to build test cases in order to start development, I was told there's no time! When you're used to working in a test-first TDD/BDD manner, it's really jarring to be told not to do that. I mean, how would I know what I was writing was remotely correct?

I'll end this extended digression by saying I understand why the interview was conducted like this: it's quick, it makes the candidate talk about what they are doing, and there really aren't many great ways to assess a candidate in depth without taking half a day. But it feels like there should be a better way.

Tail Recursion

When the interview was over, the interviewer said they were very impressed with the solution and said they hadn't seen anyone solve this problem with recursion. Now given that the interview was a Scala interview, that perplexed me a bit. Were they writing Java style loops? So I thought I'd dig into recursion a bit and cover why I think my solution with recursion was the right one.

The classic example for looping and recursion is the factorial function. The factorial of a number n is defined as

1 * 2 * ... * n

So 5 factorial would be

1 * 2 * 3 * 4 * 5 = 120

Writing this up as a loop is pretty simple. Let's use a Java while loop

public long factorial(int n) {
    long result = 1;
    int i = 1;
    while (i <= n) {
        result *= i;
        i++;
    }
    return result;
}

This has the undesirable Java-ism of mutable state: both i and result are mutated. It also uses a 64-bit long, which overflows as soon as n goes past 20. It's probably fine for this toy example, but we can do better in Scala

def factorial(n: Int): BigInt = {
    if (n <= 1) 1
    else n * factorial(n - 1)
}

Great, no mutation, and it's nice and concise. It even uses BigInt (which wraps Java's BigInteger) for arbitrarily large integers, which is just as well since the factorial of 100 is 158 digits long!
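To make the overflow point concrete, here's a minimal sketch that runs the same recursion with Long and with BigInt. The object and function names (OverflowSketch, factorialLong, factorialBig) are ones I've made up for illustration.

```scala
object OverflowSketch {
  // Hypothetical helper: the same recursion, but with 64-bit Long arithmetic.
  def factorialLong(n: Int): Long =
    if (n <= 1) 1L else n * factorialLong(n - 1)

  // Hypothetical helper: the same recursion with arbitrary-precision BigInt.
  def factorialBig(n: Int): BigInt =
    if (n <= 1) BigInt(1) else BigInt(n) * factorialBig(n - 1)

  def main(args: Array[String]): Unit = {
    println(factorialLong(20)) // still fits in a Long
    println(factorialLong(21)) // silently wraps around to a negative number
    println(factorialBig(21))  // the real value
    println(factorialBig(100).toString.length) // number of digits in 100!
  }
}
```

The Long version doesn't throw anything when it overflows, it just quietly returns the wrong answer, which is arguably worse than a crash.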

But there's still something wrong. To get a clue what it is, let's try the factorial of 20,000

factorial(20000)
...
Exception in thread "main" java.lang.StackOverflowError
...

Uh oh! We clearly have a problem. A Stack Overflow, no less. To understand the issue, let's take a quick look at our old friend the stack.

The Stack

The stack is the region of memory where local variables are stored, and its main purpose is to support function calls. Each call to a function creates a new frame on the stack, and that frame holds the call's local variables until the call returns.

In our example, the final line computes n * factorial(n - 1). The multiplication can't happen until the recursive call returns, so the value of n has to stay on the stack in the meantime. The program will keep making stack frames for n, n-1, n-2 and so on. Since stack space is limited, and each stack frame takes up memory, we eventually run out of space and the JVM throws a StackOverflowError.
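One way to picture those waiting frames is to render the multiplications that are still pending when the recursion finally hits its base case. This is just an illustrative sketch, and the pending helper is a name I've invented for it.

```scala
object PendingWork {
  // Hypothetical helper: builds the chain of multiplications that the
  // non-tail-recursive factorial leaves waiting on the stack.
  // Each level of nesting corresponds to one stack frame.
  def pending(n: Int): String =
    if (n <= 1) "1" else s"$n * (${pending(n - 1)})"

  def main(args: Array[String]): Unit =
    println(pending(4)) // prints: 4 * (3 * (2 * (1)))
}
```

For factorial(20000) that expression would be nested 20,000 levels deep, and every level is a frame that can't be released yet.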

Luckily for us, Scala has support for a stack optimization for tail recursive functions. If you define a recursive function so that the recursive call is the very last thing it does, with nothing left on the stack to finish afterwards, then the Scala compiler will re-use the same stack frame for each call and thus prevent a stack overflow.

Here's the tail recursive version

import scala.annotation.tailrec

@tailrec
def factorial(n: Int, accumulator: BigInt = 1): BigInt = {
    if (n <= 1) accumulator
    else factorial(n - 1, accumulator * n)
}

The main difference here is that, rather than leaving n on the stack to be multiplied in after a chain of function calls returns, we're explicitly calculating an accumulator and passing it along.

And running it allows us to see that factorial of 20,000 has 77,338 digits. Cool!
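To see the accumulator doing its job, here's a small sketch that records the arguments of each call. The steps function and its seen parameter are hypothetical additions for tracing only; they aren't part of the factorial itself.

```scala
import scala.annotation.tailrec

object AccumulatorTrace {
  // Hypothetical tracing variant: same shape as the tail recursive
  // factorial, but it collects the (n, accumulator) pair seen by each call.
  @tailrec
  def steps(
      n: Int,
      accumulator: BigInt = 1,
      seen: Vector[(Int, BigInt)] = Vector.empty
  ): Vector[(Int, BigInt)] = {
    val now = seen :+ (n -> accumulator)
    if (n <= 1) now
    else steps(n - 1, accumulator * n, now)
  }

  def main(args: Array[String]): Unit =
    // For n = 4 the pairs are (4,1), (3,4), (2,12), (1,24)
    steps(4).foreach(println)
}
```

By the time n reaches 1 the accumulator already holds the answer, so there's nothing left to do on the way back up.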

Note the annotation to say that it's tail recursive. This is a compile time convenience: if the function is changed so that it's no longer tail recursive, you'll get a compile error. The compiler will still optimize a tail recursive function without the annotation.
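Here's a rough sketch of both points. The commented-out naive function shows the kind of definition the annotation rejects, and factorialLoop is my hand-written approximation of how the optimized function behaves, not the compiler's actual output.

```scala
import scala.annotation.tailrec

object TailrecCheck {
  // Compiles: the recursive call is the last thing the function does.
  @tailrec
  def factorial(n: Int, accumulator: BigInt = 1): BigInt =
    if (n <= 1) accumulator
    else factorial(n - 1, accumulator * n)

  // Uncommenting this should fail to compile, because the multiplication
  // still has to run after the recursive call returns:
  // @tailrec
  // def naive(n: Int): BigInt =
  //   if (n <= 1) 1 else n * naive(n - 1)

  // Hypothetical loop equivalent: one frame, with the parameters
  // updated in place on every iteration.
  def factorialLoop(n: Int): BigInt = {
    var i = n
    var acc = BigInt(1)
    while (i > 1) {
      acc = acc * i
      i -= 1
    }
    acc
  }

  def main(args: Array[String]): Unit =
    println(factorial(20000).toString.length) // digit count of 20000!
}
```

So we get to write the function without any mutation, and the mutation happens under the hood where we don't have to look at it.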

So there we are. Recursion is a pretty basic technique in FP and, while you might not use it for exactly these purposes, it is a perfect fit for recursive data structures, like trees and tries.
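As a parting sketch, here's what recursion over a tree might look like. The Tree, Leaf and Node types are made up for this example, and note that sum is not tail recursive, since it makes two recursive calls; its stack depth is bounded by the height of the tree rather than the number of nodes.

```scala
object TreeSketch {
  // Hypothetical binary tree: a node is either empty or holds a value
  // with a left and a right subtree.
  sealed trait Tree
  case object Leaf extends Tree
  final case class Node(left: Tree, value: Int, right: Tree) extends Tree

  // The function mirrors the shape of the data: one case per constructor.
  def sum(tree: Tree): Int = tree match {
    case Leaf          => 0
    case Node(l, v, r) => sum(l) + v + sum(r)
  }

  def main(args: Array[String]): Unit = {
    val tree = Node(Node(Leaf, 1, Leaf), 2, Node(Leaf, 3, Leaf))
    println(sum(tree)) // 1 + 2 + 3
  }
}
```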
