Stack-safe monads

Alexander Zaidel
Apr 21, 2019


Recursion plays a huge role in the daily life of every functional programmer. Whenever we need a loop, we end up writing it as recursion, because recursion lets us eliminate mutation.

However, tail call elimination in the Scala compiler is limited to self-recursive methods whose recursive call is in tail position. To get a feel for what this means, let's look at an example that calculates the sum of a list:
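The original code sample is not included in this copy of the post. A minimal sketch of such a function, assuming it returns a Long (the type used for sums later in the post), might look like this:

```scala
// Naive recursive sum: the addition happens only after the recursive call
// returns, so every element of the list keeps its own stack frame alive.
def sum(list: List[Int]): Long = list match {
  case Nil          => 0L
  case head :: tail => head + sum(tail)
}
```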

The current implementation works correctly as long as the recursion fits on the call stack. Calling the function on a long list, however, blows up the stack:

sum((1 to 20000).toList) // throws StackOverflowError

One possible solution is to re-implement sum as a tail-recursive function:
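A sketch of that rewrite, again assuming a Long result, moves the running total into an accumulator parameter of a local helper (named loop here purely for illustration), so the recursive call becomes the last operation:

```scala
import scala.annotation.tailrec

def sum(list: List[Int]): Long = {
  // The accumulator carries the running total, so the recursive call is the
  // last thing `loop` does and the compiler can turn it into a loop.
  @tailrec
  def loop(rest: List[Int], acc: Long): Long = rest match {
    case Nil          => acc
    case head :: tail => loop(tail, acc + head)
  }
  loop(list, 0L)
}
```

Because @tailrec makes the compiler reject the method if the call is not in tail position, the annotation doubles as a check that the optimization actually applies.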

It gets more challenging when we calculate the sum of a List[Option[Int]]:
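The original sample is missing here as well. A plausible sketch, assuming that a single None makes the whole sum None, chains the recursive call through Option's flatMap and map (desugared from the for-comprehension):

```scala
// Naive recursive sum over optional values: the result is None as soon as
// any element is None. The recursive call happens inside flatMap/map.
def sum(list: List[Option[Int]]): Option[Long] = list match {
  case Nil => Some(0L)
  case head :: tail =>
    for {
      h    <- head
      rest <- sum(tail)
    } yield h + rest
}
```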

The implementation above has the same stack overflow issue when computing the sum for a long list:

sum((1 to 20000).map(Option(_)).toList) // throws StackOverflowError

Why don't we follow the same pattern and make sum tail-recursive, as in the previous example?
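A sketch of that attempt (our reconstruction, not necessarily the author's exact code) threads an accumulator through flatMap:

```scala
// Passing the running total along looks tail-recursive, but the recursive
// call is made inside the lambda given to flatMap, not in tail position of
// `sum` itself. Annotating this method with @tailrec would not compile.
def sum(list: List[Option[Int]], acc: Long): Option[Long] = list match {
  case Nil          => Some(acc)
  case head :: tail => head.flatMap(h => sum(tail, acc + h))
}
```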

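Unfortunately, the Scala compiler doesn't know how to optimize monadic recursion: the recursive call happens inside the function passed to flatMap, so it is not in tail position, and annotating sum with @tailrec makes compilation fail.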
But the biggest surprise is that there is no StackOverflowError if we replace Option with Future:
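A sketch of the Future variant keeps the same shape. It assumes the list elements are already-completed futures built with Future.successful and that the global ExecutionContext is used (both our assumptions). The code lives in a hypothetical FutureSum object, and Await is used only to read the result of the demo:

```scala
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._

object FutureSum {
  // Same shape as the Option version, with Future as the monad.
  def sum(list: List[Future[Int]], acc: Long): Future[Long] = list match {
    case Nil          => Future.successful(acc)
    case head :: tail => head.flatMap(h => sum(tail, acc + h))
  }

  // Blocks with Await only to read the result in this small demo.
  def total(): Long =
    Await.result(sum((1 to 20000).map(Future.successful(_)).toList, 0L), 10.seconds)
}
```

FutureSum.total() returns 200010000, the sum of 1 to 20000.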

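At this point we can see that monads are not equal in terms of stack safety, and improper use can lead to unexpected results. Future gets away with it because, with the usual global ExecutionContext, flatMap does not run its callback on the current call stack: the callback is submitted as a task to the ExecutionContext, so the recursive steps do not pile up on a single stack.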

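So what should we do about it? Should we start writing iterative loops with mutation?
That is possible, but not necessary: trampolines come to the rescue. In short, trampolining trades stack for heap. Instead of making a nested call, each step returns a value describing the next step, and a simple loop runs those steps one after another, so the stack depth stays constant.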
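To make the idea concrete, here is a minimal hand-rolled trampoline. Every name in it (Trampoline, Done, More, run, isEven, isOdd) is hypothetical and chosen for illustration; it is not the standard library API discussed below. Mutually recursive isEven and isOdd are used because the compiler cannot optimize calls between two different methods:

```scala
import scala.annotation.tailrec

// Hypothetical minimal trampoline: a step is either finished or a thunk
// that produces the next step.
sealed trait Trampoline[A]
final case class Done[A](value: A) extends Trampoline[A]
final case class More[A](next: () => Trampoline[A]) extends Trampoline[A]

// The loop that "bounces": it evaluates one thunk at a time, so the
// stack depth does not depend on the number of steps.
@tailrec
def run[A](t: Trampoline[A]): A = t match {
  case Done(value) => value
  case More(next)  => run(next())
}

// Mutual recursion that would otherwise overflow for large n.
def isEven(n: Int): Trampoline[Boolean] =
  if (n == 0) Done(true) else More(() => isOdd(n - 1))

def isOdd(n: Int): Trampoline[Boolean] =
  if (n == 0) Done(false) else More(() => isEven(n - 1))
```

run(isEven(100000)) returns true without growing the stack, because each More only hands back a thunk and the loop in run evaluates it.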

Scala's standard library already provides the necessary abstractions in the scala.util.control.TailCalls object.

A quick guide to using it:

  • change the recursive function's return type to TailRec[T] (TailRec is itself a monad);
  • wrap the base case of the recursive function in a call to done:
def done[A](result: A): TailRec[A] = Done(result)
  • wrap the recursive call in tailcall:
def tailcall[A](rest: => TailRec[A]): TailRec[A] = Call(() => rest)
  • call result on the returned TailRec to run the computation.

For the sake of simplicity, we will implement the sum function for a List[Int] first:
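The original sample is not reproduced here. A sketch using the real TailCalls API (done, tailcall, map and result all exist on scala.util.control.TailCalls), written in the original non-tail-recursive shape, might look like this:

```scala
import scala.util.control.TailCalls._

// Not tail-recursive: the addition happens after the recursive call.
// tailcall suspends the recursive call, and `result` evaluates the
// suspended steps in a loop, keeping the pending work on the heap
// instead of on the call stack.
def sum(list: List[Int]): TailRec[Long] = list match {
  case Nil          => done(0L)
  case head :: tail => tailcall(sum(tail)).map(_ + head)
}
```

sum((1 to 20000).toList).result evaluates to 200010000 without a StackOverflowError.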

Doing the same for the sum of a List[Option[Int]]:
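One way to combine the two, sketched here under the same "any None gives None" assumption, is to let TailRec handle the recursion and Option handle the missing values, so the return type becomes TailRec[Option[Long]]:

```scala
import scala.util.control.TailCalls._

// TailRec takes care of the recursion; Option takes care of missing values.
// The Option results are combined once the rest of the list has been summed.
def sum(list: List[Option[Int]]): TailRec[Option[Long]] = list match {
  case Nil => done(Some(0L))
  case head :: tail =>
    tailcall(sum(tail)).map { rest =>
      for {
        h <- head
        r <- rest
      } yield h + r
    }
}
```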

The Cats library tackles the monadic recursion problem by requiring a tailRecM method on its FlatMap type class:

def tailRecM[A, B](a: A)(f: A => F[Either[A, B]]): F[B]

tailRecM applies f to a, feeds every Left[A] it gets back into f again, and stops as soon as f returns a Right[B] inside F:

For a monad whose flatMap is already stack-safe, Cats offers a default tailRecM implementation written in terms of flatMap (the StackSafeMonad trait).
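The Cats source is not reproduced here. As a self-contained sketch, the code below defines a hypothetical, trimmed-down SimpleMonad trait (not the Cats API), a SimpleStackSafeMonad trait that mirrors the idea of the default implementation, and an instance for Future, which we saw earlier is safe to recurse through:

```scala
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._

// Hypothetical stand-in for cats' Monad, keeping only what this example needs.
trait SimpleMonad[F[_]] {
  def pure[A](a: A): F[A]
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]
  def tailRecM[A, B](a: A)(f: A => F[Either[A, B]]): F[B]
}

// Mirrors the idea of a default for stack-safe monads: when flatMap itself is
// stack-safe, tailRecM can be plain recursion through flatMap.
trait SimpleStackSafeMonad[F[_]] extends SimpleMonad[F] {
  override def tailRecM[A, B](a: A)(f: A => F[Either[A, B]]): F[B] =
    flatMap(f(a)) {
      case Left(next) => tailRecM(next)(f) // keep going with the new state
      case Right(b)   => pure(b)           // finished
    }
}

object FutureExample {
  // Future qualifies: its flatMap callbacks run as tasks on the ExecutionContext.
  val futureMonad: SimpleMonad[Future] = new SimpleStackSafeMonad[Future] {
    def pure[A](a: A): Future[A] = Future.successful(a)
    def flatMap[A, B](fa: Future[A])(f: A => Future[B]): Future[B] = fa.flatMap(f)
  }

  // Counts down from n to 0 through tailRecM, one step per iteration.
  def countdown(n: Int): Int =
    Await.result(
      futureMonad.tailRecM[Int, Int](n) { i =>
        Future.successful(if (i == 0) Right(i) else Left(i - 1))
      },
      10.seconds
    )
}
```

FutureExample.countdown(20000) returns 0 after 20000 iterations of tailRecM.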

The Cats documentation says:

any operation that you would write using recursive flatMap can be rewritten to use tailRecM

Our plan is to encode (List[Option[Int]], Long) => Option[Long] into tailRecM:

def tailRecM[A, B](a: A)(f: A => F[Either[A, B]]): F[B]

Think about the context F first. We are writing a function that calculates the sum of a List[Option[Int]] and returns Option[Long], so F is Option and B is Long. A, in turn, holds the input parameters: the remaining list and the running total, (List[Option[Int]], Long):
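With those types fixed, the function f passed to tailRecM has the type (List[Option[Int]], Long) => Option[Either[(List[Option[Int]], Long), Long]]. A sketch of it, under the hypothetical name step and the same "any None gives None" assumption:

```scala
// A = (List[Option[Int]], Long): remaining elements and running total.
// B = Long: the final sum. F = Option.
def step(state: (List[Option[Int]], Long)): Option[Either[(List[Option[Int]], Long), Long]] =
  state match {
    case (Nil, acc)             => Some(Right(acc))            // done: emit the total
    case (None :: _, _)         => None                        // a missing value short-circuits
    case (Some(h) :: tail, acc) => Some(Left((tail, acc + h))) // continue with the rest
  }
```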

Defining an instance of Monad[Option]:
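Without depending on Cats, a self-contained sketch can define the instance against a hypothetical SimpleMonad trait that only mimics the shape of Cats' Monad. Because Option's flatMap is not stack-safe, tailRecM is implemented as a self-recursive @tailrec loop, which keeps the stack constant however many Left steps f produces:

```scala
import scala.annotation.tailrec

// Hypothetical stand-in for cats' Monad (not the cats API itself).
trait SimpleMonad[F[_]] {
  def pure[A](a: A): F[A]
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]
  def tailRecM[A, B](a: A)(f: A => F[Either[A, B]]): F[B]
}

object OptionMonad extends SimpleMonad[Option] {
  def pure[A](a: A): Option[A] = Some(a)

  def flatMap[A, B](fa: Option[A])(f: A => Option[B]): Option[B] = fa.flatMap(f)

  // Self-recursive in tail position, so the compiler turns it into a loop.
  @tailrec
  def tailRecM[A, B](a: A)(f: A => Option[Either[A, B]]): Option[B] =
    f(a) match {
      case None             => None              // short-circuit on a missing value
      case Some(Left(next)) => tailRecM(next)(f) // continue with the new state
      case Some(Right(b))   => Some(b)           // finished
    }
}
```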

With all the pieces in place:
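Putting it together in one self-contained sketch, with a hypothetical OptionMonad object that carries only the loop-based tailRecM, sum starts from the full list and a zero total and lets tailRecM drive the iteration:

```scala
import scala.annotation.tailrec

// Hypothetical stand-in for Monad[Option]; only tailRecM is needed here.
object OptionMonad {
  @tailrec
  def tailRecM[A, B](a: A)(f: A => Option[Either[A, B]]): Option[B] =
    f(a) match {
      case None             => None
      case Some(Left(next)) => tailRecM(next)(f)
      case Some(Right(b))   => Some(b)
    }
}

// A = (List[Option[Int]], Long): the remaining elements and the running total.
def sum(list: List[Option[Int]]): Option[Long] =
  OptionMonad.tailRecM[(List[Option[Int]], Long), Long]((list, 0L)) {
    case (Nil, acc)             => Some(Right(acc))            // all elements consumed
    case (None :: _, _)         => None                        // a missing value ends the sum
    case (Some(h) :: tail, acc) => Some(Left((tail, acc + h))) // next iteration
  }
```

With Cats on the classpath, the hand-written OptionMonad would presumably be replaced by the library's own Monad[Option] instance.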

sum((1 to 20000).map(Option(_)).toList) // it works!!!

Today we learnt that some monads are stack-safe (like Future) and others are not. We also covered two approaches to making monadic recursion safe for monads that are not: Scala's built-in TailRec abstraction and tailRecM from Cats.
