Scala tail call optimization

Matthew Michihara
Jan 18, 2018


Consider the following recursive function that computes n!.

def factorial(n: Int): BigInt = {
  if (n == 0) {
    1
  } else {
    n * factorial(n - 1)
  }
}

This is an elegant, straightforward translation of the mathematical definition of factorial: n! = n * (n - 1)!, with 0! = 1. However, for large values of n, such as n = 10000, running it throws a StackOverflowError with a gigantic stack trace.

Exception in thread "main" java.lang.StackOverflowError
	at scala.math.BigInt$.apply(BigInt.scala:39)
	at scala.math.BigInt$.int2bigInt(BigInt.scala:97)
	at com.fourpool.TailRec$.factorial(TailRec.scala:15)
	at com.fourpool.TailRec$.factorial(TailRec.scala:15)
	at com.fourpool.TailRec$.factorial(TailRec.scala:15)
...
...
...
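To reproduce this, here is a minimal sketch that wraps the same definition in a runnable program and catches the error. The object name NaiveFactorialDemo is hypothetical. Whether n = 10000 actually overflows depends on the JVM's thread stack size (set with -Xss), so the second line of output depends on your environment.

```scala
object NaiveFactorialDemo {
  // Same non-tail-recursive definition as above: the multiplication by n
  // happens only after the recursive call returns, so every call keeps a frame.
  def factorial(n: Int): BigInt = {
    if (n == 0) {
      1
    } else {
      n * factorial(n - 1)
    }
  }

  def main(args: Array[String]): Unit = {
    println(factorial(5)) // small inputs work fine: prints 120
    try {
      println(factorial(10000))
    } catch {
      // Reached only if the stack is too small for ~10000 nested frames.
      case _: StackOverflowError => println("stack overflow at n = 10000")
    }
  }
}
```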

This is unfortunate, but it makes sense. Each recursive call to factorial allocates a new stack frame, and that frame has to remember the pending "n * ..." multiplication for its value of n until the call below it returns.

factorial(3)
(3 * factorial(2))
(3 * (2 * factorial(1)))
(3 * (2 * (1 * factorial(0))))
(3 * (2 * (1 * 1)))
(3 * (2 * 1))
(3 * 2)
(6)

Each line above where the expression grows corresponds to a new stack frame. That frame is allocated to keep track of the work left to do once we hit the recursive base case.

Consider this variant of factorial.

def factorial(n: Int, accum: BigInt = 1): BigInt = {
  if (n == 0) {
    accum
  } else {
    factorial(n - 1, n * accum)
  }
}

The difference in this implementation is that it does not rely on the stack to keep track of state as we recurse down. Instead, state is passed along in factorial's parameters, accum in this case. Once we reach the base case, we’re done: there is no need to backtrack up the stack to finish the computation.

factorial(n = 3, accum = 1)
factorial(n = 2, accum = 3)
factorial(n = 1, accum = 6)
factorial(n = 0, accum = 6)
6
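To check the trace above, here is a small sketch that prints each call's arguments on entry. The object name TraceDemo is hypothetical. The println runs before the recursive call, so it does not change the method's recursive structure. Running it should print the same four calls followed by 6.

```scala
object TraceDemo {
  // Accumulator version from above, with a println added on entry to each call.
  def factorial(n: Int, accum: BigInt = 1): BigInt = {
    println(s"factorial(n = $n, accum = $accum)")
    if (n == 0) {
      accum
    } else {
      factorial(n - 1, n * accum)
    }
  }

  def main(args: Array[String]): Unit =
    println(factorial(3))
}
```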

This implementation is tail recursive, meaning that the last (tail) action of the method is the recursive call. The first example is not tail recursive because its last action is the multiplication by n, which can only happen after the recursive call returns.

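The nice thing about Scala is that the compiler automatically optimizes tail-recursive methods so they do not allocate a new stack frame for every call; the self-recursive call is compiled into a jump, effectively a loop. This applies to methods that call themselves directly and cannot be overridden, for example methods defined in an object or marked final or private. Testing this with the same n = 10000 returns an actual result (a very long number) instead of overflowing the stack.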
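Conceptually, the optimized method behaves like a loop. The sketch below is a hand-written equivalent, not the compiler's actual output, and the object name LoopFactorial is hypothetical. The parameters become mutable locals, and each "recursive call" becomes one more loop iteration, so stack usage stays constant no matter how large n is.

```scala
object LoopFactorial {
  // Loop form of factorial(n, accum): the parameters become mutable locals
  // and the tail call factorial(n - 1, n * accum) becomes a loop iteration.
  def factorial(n0: Int, accum0: BigInt = 1): BigInt = {
    var n = n0
    var accum = accum0
    while (n != 0) {
      // Both new values use the old n, so update accum before decrementing n.
      accum = n * accum
      n = n - 1
    }
    accum
  }

  def main(args: Array[String]): Unit = {
    println(factorial(3)) // prints 6
    // Runs without StackOverflowError; prints the number of digits of 10000!.
    println(factorial(10000).toString.length)
  }
}
```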

We can increase our confidence that Scala will tail call optimize a method by annotating it with @tailrec (imported from scala.annotation.tailrec). The annotation turns recursive methods that we mistakenly thought were tail call optimizable into compile errors rather than runtime landmines. Attempting to compile the first implementation of factorial with the @tailrec annotation yields the following error:

Error:(16, 9) could not optimize @tailrec annotated method factorial: it contains a recursive call not in tail position
n * factorial(n - 1)
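For contrast, here is a sketch of the accumulator version with the annotation applied; the object name TailRecFactorial is hypothetical. It should compile without error, because the recursive call is in tail position and a method defined in an object cannot be overridden.

```scala
import scala.annotation.tailrec

object TailRecFactorial {
  // @tailrec makes compilation fail if the recursive call below is ever
  // moved out of tail position (e.g. rewritten as n * factorial(n - 1)).
  @tailrec
  def factorial(n: Int, accum: BigInt = 1): BigInt = {
    if (n == 0) {
      accum
    } else {
      factorial(n - 1, n * accum)
    }
  }

  def main(args: Array[String]): Unit =
    println(factorial(20)) // prints 2432902008176640000
}
```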
