At BigData Republic we regularly follow courses to keep our knowledge up to par (such as this excellent course on functional programming by John de Goes). One of the issues briefly touched upon in this course was the stack and heap safety of recursive functions. This is something I wanted to understand better, so I decided to dive in and look into various techniques.
In this post I’ll address both stack and heap safety, but let’s first start with a simple example demonstrating the stack issue:
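The original listing is not preserved here, so below is a minimal reconstruction of the naive recursive sum, matching the pattern-match line discussed later in this post:

```scala
// Naive recursion: add the head to the sum of the tail.
def sum1(l: List[Int]): Int = l match {
  case Nil     => 0
  case i :: is => i + sum1(is)  // not a tail call: the addition happens after the recursive call returns
}

sum1(List(1, 2, 3))  // 6
// With a default JVM stack size, a large enough list blows the stack:
// sum1((1 to 1000000).toList)  // throws java.lang.StackOverflowError
```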
When we call this function sum1 with a large enough List, a StackOverflowError will be thrown. So, what exactly is happening here?
Exception in thread "main" java.lang.StackOverflowError
A quick reminder on memory management. Each application on the JVM makes use of two memory areas: the stack and the heap.
Stack — Access to stack memory is in Last-In-First-Out order. Each time a function or method is called, a block of memory is reserved on the stack. In this block the values of the arguments are stored, such as the list argument of our function sum1. Once a function is done with its work, the block is released. Thus, the stack grows and shrinks as the program proceeds. If this memory is completely used up and a new block needs to be allocated, the JVM will throw a
java.lang.StackOverflowError. In other words, our function does not complete.
Heap — All dynamically created objects (e.g. new XXX) are created on the heap, while references to these objects are kept on the stack. Access to and management of the heap space uses more complex memory management techniques, which makes accessing objects on the heap slower than accessing objects on the stack. Another complicating factor is that memory on the heap is not automatically de-allocated, so garbage collection is required. When the heap is full, the runtime will throw a
java.lang.OutOfMemoryError. As there is normally far more memory available for the heap, using the heap for growing structures is less error-prone than using the stack.
So, back to the example above. The problem is in this line:
case i :: is => i + sum1(is)
For each element in the list a recursive call to sum1 is made, with the remainder of the list passed as an argument. This continues until we reach the end of the list. Now, the List data structure is quite efficient, so we're not copying the full list on every call. Nevertheless, it takes only a relatively small list to cause the stack to overflow.
A first stab — Trampolining
We can make recursive functions stack safe by using a technique called trampolining. This can be achieved by using a specific Monadic implementation:
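The post's listing used scalaz.zio.IO; as a self-contained sketch of the same trampolining idea, here is the equivalent using the standard library's scala.util.control.TailCalls as a stand-in (the shape of the code is the same, only the monad differs):

```scala
import scala.util.control.TailCalls._

// Same shape as sum1, but each step returns a description (TailRec)
// instead of computing eagerly; `tailcall` suspends the recursion.
def sum2(l: List[Int]): TailRec[Int] = l match {
  case Nil     => done(0)
  case i :: is => tailcall(sum2(is)).map(i + _)
}

// The trampoline interprets the structure in a loop, so no stack overflow:
sum2(List.fill(100000)(1)).result
```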
This implementation does not return the resulting Int, but wraps the result in the
scalaz.zio.IO Monad. To make this work, we have to map over the result of the recursive call to do the Int addition. Other than that the implementation looks quite similar, including the recursion. At first sight, one would expect the stack overflow to happen here as well. Look at this line:
case i :: is => sum2(is).map(i + _)
The map can only be performed when the sum2 call returns, right? So we would expect the stack to fill up again. However, that's not the case, because IO implements a technique called trampolining. Let's take a look at the map function of
scalaz.zio.IO to see what's going on:
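The original post embedded the library source here. As a rough, simplified sketch of the pattern (illustrative only, not the actual scalaz-zio code):

```scala
// Illustrative sketch only — map does not run anything,
// it wraps the existing description in a new node.
final def map[B](f: A => B): IO[E, B] = this match {
  case io: IO.Point[E, A] => new IO.Point(() => f(io.value()))  // defer f until the value is demanded
  // ... analogous wrapping for Strict, Fail, and the remaining instances
}
```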
The thing to note here is that map is not calculating a value, but building up a data structure. For each possible instance of IO (Point, Strict, Fail, and the remaining instances), a new object is created, which gets a function as its argument, e.g.:
new IO.Point(() => f(io.value()))
Meaning, to get the actual value of this Point instance, we'd need to call its value function of type
Unit => A, which only then executes the function f.
This approach of building up a data structure allows for constructing optimized interpreters. To some extent it makes our sum2 function stack safe, as we can now grow the IO data structure as large as the heap allows. Once the final IO, representing the overall program, is built up, it can be interpreted and executed. It's quite informative to study, for example, the evaluation code of IO.
However, when interpreting the structure we still run the risk of having to make too many nested calls, resulting in a stack overflow. And even if that can be avoided, a heap overflow could occur instead. In general, though, this approach already makes recursive functions a lot safer, but we're not there yet.
A second stab — tail recursion
Trampolining makes recursive functions stack safe but not heap safe, as each new object used to build the data structure is allocated on the heap. Consequently, if the function recurses too often, an
OutOfMemoryError will occur at some point. To make recursive code both stack and heap safe, we need to play another card: making sure that the function is not only recursive, but tail recursive. So, what is tail recursion and why does it matter?
Tail recursive functions do just one thing in the recursive case: return the result of the recursive call. We can immediately see why this is desirable: if only the result has to be returned, no state for the current context has to be kept when entering the recursive call. Consequently, tail recursive functions can be made stack constant by the compiler.
What does it mean for a function to be stack constant? If we look at the expression
i + sum1(is), we need a stack frame for the current invocation of sum1 (to store the actual value of i). Besides that, the recursive call will take a new stack frame for the same reason. Now, a tail recursive function only returns the result of the recursive call. This means that when entering a tail recursive call, the current stack frame can be replaced with that of the recursive call. As a result, the stack does not grow but stays constant.
How do we achieve tail recursion? This is actually pretty easy. Instead of recursively building up the final state, we build up the current state directly and pass it along:
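A sketch of what that accumulator version looks like (the original listing is not preserved, so the exact shape is assumed):

```scala
// The running total is threaded through as an argument, so the recursive
// call is the last thing the function does.
def sum3(l: List[Int], acc: Int): Int = l match {
  case Nil     => acc
  case i :: is => sum3(is, acc + i)  // tail call: nothing left to do afterwards
}

sum3(List.fill(100000)(1), 0)  // runs in constant stack space
```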
The first thing we see is that in the recursive case this function indeed does nothing more than pass on the result of the recursive call. There are no further manipulations of the return value of the recursive call before it is returned as the result of the current call. Besides that, we suddenly need a bit of extra information to use our new
sum3 function: an initial value for the accumulator. Obviously, in the case of sum, this should be 0.
To get our original signature back, we use an inner function. We mark this inner function with the
@tailrec annotation, which makes the compiler verify that the function can indeed be compiled with tail call optimization.
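Put together, a sketch of the wrapped version (the function and loop names are assumed; the original listing is not preserved):

```scala
import scala.annotation.tailrec

// Same signature as sum1 again; the accumulator is hidden in an inner loop.
def sum4(l: List[Int]): Int = {
  @tailrec  // compilation fails if the recursive call is not in tail position
  def loop(rem: List[Int], acc: Int): Int = rem match {
    case Nil     => acc
    case i :: is => loop(is, acc + i)
  }
  loop(l, 0)
}
```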
Finally, we can apply tail recursion to the monadic version, making it stack and heap safe:
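The post used scalaz.zio.IO here; as a self-contained sketch of the same idea, using the standard library's scala.util.control.TailCalls as a stand-in (names assumed):

```scala
import scala.util.control.TailCalls._

// Monadic and tail recursive: each suspension is interpreted before the next
// one is built, so neither the stack nor the heap grows with the list.
def sum5(l: List[Int]): TailRec[Int] = {
  def loop(rem: List[Int], acc: Int): TailRec[Int] = rem match {
    case Nil     => done(acc)
    case i :: is => tailcall(loop(is, acc + i))  // no .map afterwards: nothing to defer
  }
  loop(l, 0)
}

sum5(List.fill(100000)(1)).result
```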
I have shown that the general solution for functional iteration, namely recursion, can lead to memory problems due to stack and heap overflows. I explained why this happens, and outlined two strategies to counter the problem:
- Trampolining: build up a data structure in the recursive call instead of executing functions. Actual interpretation of this data structure can be delayed and implemented in an optimized manner.
- Make recursive functions tail recursive. The only thing that should happen in the recursive case is returning the result of the recursive call. This allows the compiler to apply tail call optimization and makes the function stack safe.
The code where you need to apply these strategies is often not as straightforward as in the example above. However, if you find yourself typing a recursive call, you now have strategies to keep memory issues at bay.