How Functional Programming can be Awesome: Tail Recursion Elimination
Tail Recursion Elimination is a very interesting feature available in Functional Programming languages, like Haskell and Scala. It makes recursive function calls almost as fast as looping.
In my latest article about Functional Programming features in Python, I said map was a bit redundant given the existence of List Comprehensions, and didn’t paint lambda Expressions in a very good light either. I feel I didn’t do Functional Programming in general any justice, since I do like it as an elegant way to structure programs. It just doesn’t go that well with Python’s style and philosophy. As an offhand remark, I mentioned the lack of Tail Recursion Elimination as another controversial design decision in Python’s implementation.
I will be honest now: I wasn’t entirely sure what Tail Recursion Elimination (TRE, from now on) was when I wrote that. I knew it was an optimization that had to do with recursive function calls, and that it was present in Haskell, but not a lot more.
So I decided to compensate for that in the best way I could: by learning and writing an article about it, so this won’t happen to you!
What are recursive calls?
We say a function call is recursive when it is made inside the body of the function being called. So basically it’s a function calling itself.
Many problems (in fact, any problem you can solve with loops, and plenty that you can’t solve with loops alone) can be solved by recursively calling a function until a certain condition is met.
For instance, here’s a Python function written in both imperative and functional style:
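Something along these lines (a sketch; the function names are just for illustration):

```python
def contains_imperative(lst, elem):
    """Imperative style: walk the list with a loop."""
    for item in lst:
        if item == elem:
            return True
    return False


def contains_recursive(lst, elem):
    """Functional style: recurse on the rest of the list."""
    if not lst:          # base case: empty list, element not found
        return False
    if lst[0] == elem:   # found it at the head
        return True
    return contains_recursive(lst[1:], elem)  # recurse on the tail
```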
Both functions do the same thing in theory: given a list and an element, they check whether the element is present and return that as a bool. At a lower level though, the second implementation makes a lot of function calls, and doesn’t actually return from any of them until the last one is made. Why is this a problem?
Motivation for Recursion Elimination
Since function calls take up space on our computer’s stack, there is a hard limit to how many we can make before hitting a stack overflow: filling up the whole stack. Not only that: since each function call starts by setting up its stack frame (pushing things to memory and other costly operations), the second version is also a lot slower.
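You can see the hard limit for yourself in Python, which refuses to recurse past a configurable depth rather than blow the real stack (the function name count_down is just for demonstration):

```python
import sys

# CPython guards its stack with a recursion limit (1000 by default),
# raising RecursionError instead of actually overflowing the stack.
print(sys.getrecursionlimit())  # typically 1000

def count_down(n):
    if n == 0:
        return 0
    return count_down(n - 1)

count_down(10)         # fine
# count_down(100_000)  # RecursionError: maximum recursion depth exceeded
```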
As I said before, there are some problems you just can’t solve without recursion, or at least not nearly as elegantly.
So it would be very good if we could write our functions the second way and have them run as fast as the ones written the first way, especially if that also allowed us to avoid stack overflows.
Luckily for us, someone already found a solution to this — but first, let’s clarify something.
What is Tail Recursion?
We’ve already seen why we’d like to implement recursion in an effective way, but I’ve been talking about eliminating tail recursion, not all kinds of recursion. So what makes tail recursion special?
Tail recursion is just a particular case of recursion, where the return value of a function is exactly the value of a call to itself, and nothing else: the recursive call is the very last thing the function does.
For instance, here are two versions of the factorial function. One is tail recursive, and the other is not.
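A sketch of the two (the name factorial_tail and the accumulator’s default value are my own choices):

```python
def factorial(n):
    """Not tail recursive: the product is computed *after* the call returns."""
    if n == 0:
        return 1
    return n * factorial(n - 1)


def factorial_tail(n, acc=1):
    """Tail recursive: the recursive call is the entire return value."""
    if n == 0:
        return acc
    return factorial_tail(n - 1, acc * n)
```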
Notice how, even though the return line of the first function contains a call to itself, it also does something to its output (in this particular case computing a product) so the return value is not really the recursive call’s return value. Usually we can make a regular recursive function tail recursive through the use of an accumulator parameter, as I did in the second declaration of factorial.
Introducing Tail Recursion Elimination
The whole idea behind TRE is avoiding function calls and stack frames as much as possible, since they take time and are the key difference between recursive and iterative programs. You read that right: Functional Languages are awesome partly because they found a way to make fewer function calls.
In order to understand the next part, it’s important to go back a step and understand what exactly is going on every time we do a function call.
Whether our code is compiled (as in C or Go) or interpreted (like Python), it always ends up executing as Machine Language instructions, which we usually read and write as Assembly or similar low-level languages. These represent the lowest level of abstraction, and therefore the most granular control over memory and hardware.
Here’s what happens on every function call:
- All registers (the hardware equivalent of variables, where data are stored) are pushed onto the stack, that is, written into memory, which is much slower to access than the registers themselves.
- Your computer starts reading instructions from a different memory address (corresponding to the first line of code of the called function).
- Code is executed from that address onward, doing what the function actually does, usually changing register values in a certain way.
- All register values are popped/retrieved back from the stack, so the function we return to has its data back.
- A return statement is run, and instructions start being read from the previous function again.
Steps one and four are the costliest to run in terms of time, like most operations that deal with memory. Each push or pop usually takes over ten times as long as a ‘regular’ instruction (one that only deals with registers).
However, if those steps were skipped, a function could write values into a register, potentially overwriting the ones the caller function had written. Just imagine what would happen if every time you called print, all your variables were changed to arbitrary values.
However, in the particular case of a function calling itself, there are a few tricks we could use:
- We can store the memory address where the function starts and, instead of calling the function, just move the ‘memory reader’ (the instruction pointer) back to it at the end.
- We can write into the registers ourselves, knowing which values the previous function was expecting to get from us, without having to use the stack to restore the previous state. We know what the ‘previous function’ is expecting because it’s exactly this same function. Not only that: we don’t even have to save and restore the registers we will not alter.
That way we can avoid pushing and popping our registers back and forth, which takes a lot of time. But that’s not all: since no actual function calls are taking place (we’re only using jump statements, i.e., moving our instruction reader), we’re not filling our stack, and no stack overflow can ever occur. We don’t need to save the previous context on the stack in the first place, because we are just returning to the same function over and over. The only context we need to save is the one for the very first call to our function.
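In Python terms, the transformation amounts to something like this hand-written sketch of what an optimizer would do to the tail recursive factorial from before (Python itself does not perform TRE):

```python
def factorial_tre(n, acc=1):
    """What TRE effectively turns factorial_tail into: one frame, no calls."""
    while True:                  # the function's starting address, as a loop
        if n == 0:
            return acc
        n, acc = n - 1, acc * n  # overwrite the 'registers' ourselves...
        # ...and jump back to the start instead of calling ourselves
```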
So to sum up, TRE is an optimization that takes advantage of a very special case of function calls: functions calling themselves, and returning their output without any further processing. It uses the knowledge a function has about itself to write suitable values into the relevant registers, without having to restore the ones it did not modify during its run. It then just jumps to its own start when it calls itself, without having to move anything around on the stack.
Thanks to this feature, languages like Haskell can run implementations of recursive algorithms, which are vital to functional programming (especially for purely functional languages), just as fast as their imperative counterparts.
Here’s a very streamlined linear search in Haskell, see how elegantly it fits in just two lines! (To those actually good with Haskell, please forgive my bad practices or horrible code):
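It might look something like this (the name linearSearch and the exact formulation are a reconstruction, with the type signature added on top of the two-line definition):

```haskell
-- Linear search in two equations: an empty list contains nothing;
-- otherwise the element is either the head, or somewhere in the tail.
linearSearch :: Eq a => a -> [a] -> Bool
linearSearch _ []     = False
linearSearch x (y:ys) = x == y || linearSearch x ys
```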
I hope you now have a better understanding of what TRE is all about, and maybe about functional languages in general. If there are any parts in this explanation which you think are not clear enough, or are too detailed, please let me know in the comments, as I am still learning about writing.
If you want more Programming tutorials, tips and tricks, follow me!
And please consider showing your support for my writing.