Counting at compile time: Higher Order Functions

Abstracting over arithmetic

James Phillips
The Startup
10 min read · Jun 15, 2019

--

Counting every which way

This blog post will deal with an experimental approach to higher order functions at compile time, at the type level. It’s motivated in part by my previous post on ternary Nat, Counting to Infinity at Compile time.

You needn’t understand the details of TNat (or its binary cousin BNat), just know that they are different implementations of the concept behind shapeless.Nat that do not use the usual Succ.

What is a typelevel function?

A typeclass which takes arguments (type parameters) and maybe returns a type (an abstract type member implemented by the typeclass’s implicits). It will almost certainly follow the aux pattern. An example using ternary TNat:
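For instance, here is roughly the shape of ternary addition, in the standard aux pattern (a sketch; Numerology's real signatures may differ in detail):

trait Sum[A <: TNat, B <: TNat] { type Out <: TNat }

object Sum {
  type Aux[A <: TNat, B <: TNat, O <: TNat] = Sum[A, B] { type Out = O }
}

// 'Calling' the function: this resolves only if the compiler can prove
// that 1 + 2 = 3 (t1, t2 and t3 being the TNat encodings of 1, 2 and 3).
implicitly[Sum.Aux[t1, t2, t3]]

Sum takes two type parameters and 'returns' the abstract type member Out; the Aux alias exposes Out as a third type parameter so it can be constrained and chained.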

What is a higher order function?

A function which takes and/or returns a function.

What is a higher order typelevel function?

An aux-pattern typeclass which takes and/or returns an aux-pattern typeclass.

… But Why?

The specific motivating example is to unify typeclasses and algorithms between shapeless’ Nat and Numerology’s TNat. And indeed any other Nat implementations you may have lying around. Numerology implements at least three different versions, and you may have more.

Imagine you really needed to calculate the GCD of two compile-time integers, but couldn’t be bothered to implement the algorithm for your flavour of Nat. Well, shapeless have implemented it for their Nat, wouldn’t it be great if we could drop in and use it? The GCD of two numbers is always the same no matter how it’s calculated.

Or, suppose you wrote a version of Nat whose encoding was specialised for multiplication, and you wanted all multiplication to use it since it’s faster. We’d need some structure somewhere which returned which multiplication algorithm to use — a higher order function.

A non-Nat motivating example might be that you’ve implemented several different sorting algorithms for HList (bubble, quick, heap, whatever), and you want to be able to write a program which does not depend on a specific sorting implementation, and to be able to drop in replacement sorting algorithms without changing your main code. The sorting algorithm eventually used would be served by a higher order typelevel function.

Caveat

This is truly experimental code in this blog post, pushing the boundaries of the scala compiler so far that it only compiles in version 2.13.0-M3. There are two bugs in earlier compiler versions which prevent the code compiling: both were fixed in 2.13.0-M3, then one accidentally regressed in 2.13.0-M4 and was never fixed. I raised an issue for it, and it looks like we will have to wait for Scala 3 and Dotty to get it fixed.

Now let’s stop talking and get to it.

Unify Nat Backends

Our first step, specifically for our Nat example, is to somehow unify the different implementations of Nat. After all, the following all represent the same number, 3:

type _3 = Succ[Succ[Succ[_0]]]    // unary
type b3 = BNat.One[BNat.One[b0]]  // binary
type t3 = TNat.Zero[TNat.One[t0]] // ternary

All they are, really, are different implementations of the concept of 3. Different backends. In your normal day-to-day run-time code, do you really care how any given integer is implemented in terms of bytes? Would it make any difference to you if it was somehow ternary instead? No.

So enter Symbol.

Symbol

These are pure representations of numbers, with no implementation at all. They’re not even related to one another, except via their base type. Every piece of implementation will be added orthogonally by typelevel typeclasses. In the same way that you can add functionality to T at runtime via the typeclass pattern, we’ll be adding functionality to Symbol using typelevel typeclasses.

First, we need to map our different backends to Symbol:
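A sketch of what that typeclass might look like (illustrative names; the real Symbology signatures may differ):

sealed trait Symbol  // the base type: pure numbers, no implementation

trait SymbolMapping[Backend, S <: Symbol] { type Out <: Backend }

object SymbolMapping {
  type Aux[Backend, S <: Symbol, O <: Backend] =
    SymbolMapping[Backend, S] { type Out = O }
}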

SymbolMapping is a typelevel function which transforms a Symbol (e.g. _3) into an implementation in a given Backend (e.g. TNat's t3).

Here’s what it looks like in practice:
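A sketch of such an instance (assuming t8 is the TNat alias for eight, which is 22 in ternary):

sealed trait _8 extends Symbol             // the pure symbol for eight
type t8 = TNat.Two[TNat.Two[t0]]           // eight is 22 in ternary

implicit def tnat8: SymbolMapping.Aux[TNat, _8, t8] =
  new SymbolMapping[TNat, _8] { type Out = t8 }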

This is telling the compiler that the TNat implementation of _8 is t8. It could probably be generated by a macro, but I don’t know enough about macros to test that hypothesis.

We can implement the same for shapeless numbers:
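For instance (a sketch: here _8 is our Symbol, while shapeless.nat._8 is shapeless’ own Succ-based encoding of eight):

implicit def nat8: SymbolMapping.Aux[Nat, _8, shapeless.nat._8] =
  new SymbolMapping[Nat, _8] { type Out = shapeless.nat._8 }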

Of course, this function Symbol => Backend is one-way. We also need to be able to go the other way, Backend => Symbol. Luckily we can derive that automatically:
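A sketch of the derivation (note the N <: Backend bound):

trait ReverseSymbolLookup[Backend, N <: Backend] { type Out <: Symbol }

object ReverseSymbolLookup {
  type Aux[Backend, N <: Backend, S <: Symbol] =
    ReverseSymbolLookup[Backend, N] { type Out = S }

  // Flip any SymbolMapping: if S maps to N, then N looks up to S.
  implicit def reverseLookup[Backend, S <: Symbol, N <: Backend](
    implicit mapping: SymbolMapping.Aux[Backend, S, N]
  ): Aux[Backend, N, S] =
    new ReverseSymbolLookup[Backend, N] { type Out = S }
}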

ReverseSymbolLookup simply picks up the relevant SymbolMapping.Aux and flips the argument with the return type. We know this is well-defined because SymbolMapping should be 1–1 for any given backend.

[This incidentally is one of the bugs in Scala pre- version 2.13.0-M3. The restriction of the return type, T <: Backend where Backend is a free type parameter in reverseLookup, breaks]

Now we have those two things, imagine we have implemented the mapping for every TNat and Nat we’re interested in.

Operation Application

So far, I don’t think we’ve done anything controversial. We’ve defined a new, implementation-free version of Nat (Symbol), and given a typeclass for hopping between different backends, different views, of the number it represents.

Our next step may be to implement Sum, say, for Symbol. But that’s low-hanging fruit. We’re being general here. Why operate on Sum if we can operate on Function2?

To do this we first need a way of telling different functions apart. At runtime, multiplication and addition both have the same type signature, Int => Int => Int; they are separated only by function name (and implementation, but that is hidden from us). The same idea applies at the typelevel to the shape of our typelevel functions. Since the shapes of our abstract sum and abstract multiplication typeclasses are the same, we need one more piece of information: the operation name.

Here’s some for you:
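For example (the names are illustrative):

sealed trait OperationName
sealed trait SumOp  extends OperationName
sealed trait MultOp extends OperationName
sealed trait GcdOp  extends OperationName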

These are completely isolated tokens we pass around with no implementation, which will operate roughly the same way a function name does. They will eventually represent the operations addition, multiplication, and greatest-common-divisor respectively.

Now note that any given named Function2 in runtime code, such as +, is conceptually a function, in some abstract space, from the name to the implementation. It’s Name => Function2. We give the JVM a function name, and it returns us a Function2 object holding our algorithm.

So we need to emulate this behaviour. We need a typelevel function which goes from OperationName to a typelevel Function2, whose arguments are free to operate on any implementation of Nat. This is how we will be ‘calling’ our functions by name. Just like operating in the JVM when we ‘call’ a function sum, we need to develop a way to ‘call’ a typelevel function Sum while operating in the scala compiler.

So it looks like this:
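A sketch (the bounds tying Out’s parameters to Interface are, as the side note below says, the fragile part):

trait Operation2[Op <: OperationName] {
  type Interface
  type Out[A <: Interface, B <: Interface, O <: Interface]
}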

This is a typelevel function Op => Out, where Out is itself a typelevel function (Interface, Interface) => Interface. Interface is the backend we’re using, e.g. shapeless.Nat or TNat. You can imagine this as declaring to the compiler the type of the typelevel function we’re returning, the same way a signature declares Int to the JVM¹.

[Side note: The above restriction of Out’s type parameters to Interface is the second piece of scala that is broken in mainstream versions, and works only in 2.13.0-M3]

And since it’s a typelevel function, it comes with an Aux:
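Sketched:

object Operation2 {
  type Aux[Op <: OperationName, I, F[_ <: I, _ <: I, _ <: I]] =
    Operation2[Op] {
      type Interface = I
      type Out[A <: I, B <: I, O <: I] = F[A, B, O]
    }
}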

And here are some implementations to mull²:
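Sketches of the wirings (binary.Mult.Aux and shapeless.ops.nat.GCD.Aux here stand in for the respective libraries’ algorithm typeclasses; and note these must be defs, not vals, see [2]):

implicit def sumOp: Operation2.Aux[SumOp, TNat, ternary.Sum.Aux] =
  new Operation2[SumOp] {
    type Interface = TNat
    type Out[A <: TNat, B <: TNat, O <: TNat] = ternary.Sum.Aux[A, B, O]
  }

implicit def multOp: Operation2.Aux[MultOp, BNat, binary.Mult.Aux] =
  new Operation2[MultOp] {
    type Interface = BNat
    type Out[A <: BNat, B <: BNat, O <: BNat] = binary.Mult.Aux[A, B, O]
  }

implicit def gcdOp: Operation2.Aux[GcdOp, Nat, shapeless.ops.nat.GCD.Aux] =
  new Operation2[GcdOp] {
    type Interface = Nat
    type Out[A <: Nat, B <: Nat, O <: Nat] = shapeless.ops.nat.GCD.Aux[A, B, O]
  }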

They’re simple mappings. You can see here we’ve told SumOp to operate on TNat using the typelevel function ternary.Sum.Aux. GcdOp goes to shapeless’ definition since I never did implement it for TNat in the end, and MultOp goes to BNat, for completeness of the example.

We’ve passed the actual implementation, ternary.Sum.Aux, into this typeclass. We have no idea how that typeclass works or how it is written; all we know is that it’s a ‘function’ (TNat, TNat) => TNat. That’s very cool!

We could of course have passed in ternary.Mult.Aux instead, since that has the same ‘type signature’, and we would have a bug whereby we would multiply Symbols whenever we wanted to add them. But we didn’t do that, so we have avoided that bug. The same problems can of course exist in normal vanilla Scala code, if you use * instead of + by accident.

How do we use it?

It’s all very well and good having an Operation2 but we’ve yet to see it do anything. How do we actually apply it to any given Symbol?

Well, in runtime code you would do this: sum(int1, int2) and you’d get back an int3. This is conceptually, in some space, a function which takes three things: An operation name and two operands. You give these three things to the JVM and it returns the result.

So, we need yet another typelevel function satisfying these criteria:
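A sketch of its shape:

trait ApplyOp[Op <: OperationName, S1 <: Symbol, S2 <: Symbol] {
  type Out <: Symbol
}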

It says ‘Do this operation on these two symbols’, i.e. it represents the function of function2-application.

As usual it has an aux type, which isn’t very interesting:
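Sketched:

object ApplyOp {
  type Aux[Op <: OperationName, S1 <: Symbol, S2 <: Symbol, O <: Symbol] =
    ApplyOp[Op, S1, S2] { type Out = O }
}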

And it only has one implementation — the way we actually tell the compiler how to apply an operation name to two operands. This is the big bit, hold on to your hats:
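A sketch of it, living in ApplyOp’s companion (parameter names are mine; per the caveat above, the Backend bounds only cooperate on 2.13.0-M3):

implicit def apply2[
  Op <: OperationName, S1 <: Symbol, S2 <: Symbol,
  Backend, OpImpl[_ <: Backend, _ <: Backend, _ <: Backend],
  N1 <: Backend, N2 <: Backend, N <: Backend, Out0 <: Symbol
](
  implicit
  op:   Operation2.Aux[Op, Backend, OpImpl],       // a Backend and algorithm OpImpl for Op
  sm1:  SymbolMapping.Aux[Backend, S1, N1],        // S1 in terms of Backend: N1
  sm2:  SymbolMapping.Aux[Backend, S2, N2],        // S2 in terms of Backend: N2
  impl: OpImpl[N1, N2, N],                         // apply the algorithm: N = Op(N1, N2)
  rev:  ReverseSymbolLookup.Aux[Backend, N, Out0]  // the Symbol corresponding to N
): ApplyOp.Aux[Op, S1, S2, Out0] =
  new ApplyOp[Op, S1, S2] { type Out = Out0 }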

Let’s step through it. The algorithm (the implicit bit) says:

  • Find me an interface Backend and an algorithm OpImpl for the function named Op
  • Find me an implementation of the symbol S1 in terms of Backend and call it N1 (Note: At this point we have no idea what Backend is and we never will! It could be TNat or BNat or shapeless.Nat or something more exotic. We won’t ever know since the implementation is completely hidden from us. This is very cool!)
  • Find me an implementation of the symbol S2 in terms of Backend and call it N2.
  • Apply the algorithm OpImpl to our backend-ed symbols N1 and N2, and call the result N (also of type Backend)
  • Lookup the corresponding Symbol for N, call it Out and return it.

And, we’re done! This algorithm completely hides the implementation details of Backend and ternary.Sum.Aux and all the rest from the user, the same way Int’s version of + hides the implementation from the user. It’s completely generic.

Let’s take it for a spin:
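Something like this (hypothetical symbols, assuming the relevant SymbolMappings are in scope):

implicitly[ApplyOp.Aux[SumOp, _2, _3, _5]]   // 2 + 3 = 5
implicitly[ApplyOp.Aux[MultOp, _2, _3, _6]]  // 2 * 3 = 6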

The first example uses TNat under the hood, and the second example uses BNat under the hood. I cannot stress enough just how cool this is: the end user has no idea of the compile-time machinery turning behind the scenes.

Algorithms

Of course, now that we can apply one operation, why shouldn’t we do several things at once?

Consider the following typelevel function:
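A sketch of its signature:

trait GcdPlus7Squared[S1 <: Symbol, S2 <: Symbol] { type Out <: Symbol }

object GcdPlus7Squared {
  type Aux[S1 <: Symbol, S2 <: Symbol, O <: Symbol] =
    GcdPlus7Squared[S1, S2] { type Out = O }
}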

What we want to do is take the GCD of two symbols, add 7 to it, and square the result. By now the implementation should be easy for you to imagine:
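A sketch, chaining three ApplyOps by operation name (assuming a _7 symbol exists):

implicit def impl[
  S1 <: Symbol, S2 <: Symbol,
  G <: Symbol, P <: Symbol, O <: Symbol
](
  implicit
  gcd:    ApplyOp.Aux[GcdOp, S1, S2, G],  // G = gcd(S1, S2)
  plus7:  ApplyOp.Aux[SumOp, G, _7, P],   // P = G + 7
  square: ApplyOp.Aux[MultOp, P, P, O]    // O = P * P
): GcdPlus7Squared.Aux[S1, S2, O] =
  new GcdPlus7Squared[S1, S2] { type Out = O }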

The implicit algorithm there literally just says “Take gcd, add 7, square it”. All completely free of any knowledge of any implementations used.

Here’s an example of it working:
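For instance (hypothetical symbols; gcd(12, 8) = 4, then 4 + 7 = 11, then 11 * 11 = 121):

implicitly[GcdPlus7Squared.Aux[_12, _8, _121]]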

In fact, each step uses a different backend: mult uses BNat, sum uses TNat and gcd uses shapeless.Nat.

There’s one small thing I want to really press home here: at no point in the implementation of Symbol did I write a GCD algorithm. In fact, I’ve never even read shapeless’ implementation of GCD. And yet here we’re free to use it in the middle of an unrelated concept, one with apparently nothing to do with shapeless.Nat in the slightest. The mere act of adding some implicit SymbolMappings for shapeless has allowed us to lift all of shapeless.Nat’s functionality to Symbol, interoperating with BNat and TNat and others, without our ever having to read or understand the implementations.

All we have to do is write a few symbol mappings, and trust that the shapeless devs know what they’re doing when they write a typelevel GCD algorithm.

One final bit of coolness

As I’ve said many times by now, the algorithms we’re writing are completely free of any knowledge of the backend and algorithm used.

Imagine running GcdPlus7Squared on some input. Here’s an example, just like above:
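A hypothetical run, with the same symbols as before:

implicitly[GcdPlus7Squared.Aux[_12, _8, _121]]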

Now, that goes to shapeless.Nat, TNat and then finally BNat.

But… Recall how we mapped operations to algorithms in the first place:
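In sketch form, MultOp pointed at BNat’s multiplication:

implicit def multOp: Operation2.Aux[MultOp, BNat, binary.Mult.Aux] =
  new Operation2[MultOp] {
    type Interface = BNat
    type Out[A <: BNat, B <: BNat, O <: BNat] = binary.Mult.Aux[A, B, O]
  }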

Let’s just make a tiny change and point this at TNat’s multiplication algorithm instead:
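The same sketch, conceptually just swapping binary for ternary (and BNat for TNat):

implicit def multOp: Operation2.Aux[MultOp, TNat, ternary.Mult.Aux] =
  new Operation2[MultOp] {
    type Interface = TNat
    type Out[A <: TNat, B <: TNat, O <: TNat] = ternary.Mult.Aux[A, B, O]
  }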

Now let’s run our toy function again:
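The same hypothetical run:

implicitly[GcdPlus7Squared.Aux[_12, _8, _121]]  // the result is unchanged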

Obviously the result is the same: We didn’t change arithmetic (and it goes to show that our BNat and TNat algorithms are consistent with each other).

But that tiny change above, from BNat to TNat, has vastly changed the hidden machinery behind our symbols. The compiler now uses TNat for multiplication! A completely different algorithm with a different efficiency profile, different strengths and weaknesses, and a completely different path for the compiler, with different types and different recursion, changed by altering two small words. Our public code, the function GcdPlus7Squared, is entirely unchanged.

Pull a tiny lever and the whole world turns, and the end user has no idea you did anything at all³.

The above code and concepts are available for you to play around with in the TypeChecked project Symbology. Symbology contains an example which demonstrates vast efficiency improvements when silently swapping between shapeless and TNat algorithms behind the scenes.

Due to the limitation of using a milestone version of Scala, this is sadly unusable in normal day-to-day code. But hopefully it demonstrates the power typelevel programming can have, and may act as inspiration for more typelevel libraries in future scala versions.

[1] Our new language-at-compile-time we’re building, with our own syntax for function application and everything, isn’t actually statically typed, because it is never itself compiled. The JVM’s compile time is our run time; this new language has no corresponding compile time of its own. A ‘type error’ in this language results in an end to the scalac compilation, which is the equivalent of an exception during our run time.

[2] These implicits must be def. If you make them vals, everything stops working. I genuinely have no idea why, maybe someone smarter than me can tell me. I was very lucky that the first time I wrote this I wrote them as defs on reflex, before later tightening to vals and having everything break on me. If I’d written them as vals first I likely would have assumed the whole concept of this post was impossible.

[3] Apart from potential efficiency gains and losses. If we had instead written a TNat GCD algorithm, and swapped that in instead of using shapeless’, then our toy function would see a huge efficiency upgrade without the end user having to do a single thing. The difference between BNat and TNat multiplication at this magnitude of calculation is not large enough to see a difference in compile times.

Thanks to Jack Wheatley for the help.
