Gradient Descent with Free Monads

I was playing with Free Monads in Scala recently and discovered that they could be a perfect way of doing gradient calculations in a functional programming style. Admittedly, calculating gradients with Free Monads is not the best idea from a performance point of view, but it can be very useful for educational and experimental purposes. Once you have a way of computing gradients for an arbitrary expression at your disposal, it is not a big deal to start building simple neural networks.

A Free Monad is a perfect way to build any kind of abstract syntax tree (AST) representing a computation, while keeping that AST decoupled from the way it is interpreted.

My goal is to demonstrate how a simple gradient calculation engine can be built using Free Monads. First, we are going to define a domain model for the AST representation. Then, we define a Free Monad representing a computation. Finally, we will compute gradients both analytically and numerically using different interpreters and compare the results, which should match. In addition, we will define a simple gradient descent optimizer capable of solving simple equations defined in terms of the computation Free Monad. The code demonstrated here is available in the accompanying repository.

Computation AST Representation

We need a way to represent a computation as an AST. We can think of it as a graph whose vertices are operations and whose edges are the tensors flowing into and out of those operations. There are two kinds of edges: variables and constants.

There are also several kinds of operations, which represent the vertices of the AST graph.

With the edges (tensors) and vertices (operations) of the computational graph defined, we can represent an arbitrary computation built from a set of predefined primitive operations.
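A minimal sketch of such a domain model might look like the following. It assumes scalar values and exactly three primitive operations. The names `Tensor`, `Variable`, `Constant`, `Op`, `Sum`, `Sub` and `Mul` are hypothetical and are not necessarily the ones used in the repository.

```scala
// Hypothetical domain model for the computational graph (scalars only).
object Domain {
  // Edges: every tensor has a name and, in this scalar-only sketch, a single Double value.
  sealed trait Tensor {
    def name: String
    def value: Double
  }
  // An input we want partial derivatives for.
  final case class Variable(name: String, value: Double) extends Tensor
  // An input that takes part in the computation but is never differentiated.
  final case class Constant(name: String, value: Double) extends Tensor

  // Vertices: each primitive operation consumes tensors and produces a new tensor,
  // which is why every case is an Op[Tensor].
  sealed trait Op[A]
  final case class Sum(a: Tensor, b: Tensor) extends Op[Tensor]
  final case class Sub(a: Tensor, b: Tensor) extends Op[Tensor]
  final case class Mul(a: Tensor, b: Tensor) extends Op[Tensor]
}
```

Parameterising `Op` by its result type is what later allows individual operations to be lifted into a Free Monad.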

I am using the term tensor here. It is a mathematical abstraction over a set of values that can have different shapes. A scalar is a 0-dimensional tensor, a vector is a 1-dimensional tensor, and a matrix is a 2-dimensional tensor. Anything of higher dimensionality is simply called an n-dimensional tensor. In this example, I am using only 0-dimensional tensors, a.k.a. scalars.

In addition, it is convenient to define two extra types.

Computation Free Monad

The next step is to define the computation Free Monad. I am using the cats Scala library here.

I am not going to dive deep into what free monads are here; there are excellent articles on the topic elsewhere. In essence, given just an Op[A], the free monad lets us lift it into a monadic context, so that Op values can be combined in a monadic style. The resulting composition is stack-safe and can be interpreted separately from the place where it is defined, which in turn means we can apply multiple interpreters to the same computation expression.
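With cats, the computation Free Monad can be sketched as a type alias over `cats.free.Free`, plus one smart constructor per operation that lifts it with `Free.liftF`. The sketch repeats a small hypothetical domain model so that it compiles on its own, and it assumes cats 2.x with the `cats-free` module on the classpath. `Computation`, `sum`, `sub`, `mul` and `squareOfSum` are illustrative names.

```scala
import cats.free.Free

object ComputationMonad {
  // Hypothetical domain model (scalar tensors and three primitive operations).
  sealed trait Tensor { def name: String; def value: Double }
  final case class Variable(name: String, value: Double) extends Tensor
  final case class Constant(name: String, value: Double) extends Tensor

  sealed trait Op[A]
  final case class Sum(a: Tensor, b: Tensor) extends Op[Tensor]
  final case class Sub(a: Tensor, b: Tensor) extends Op[Tensor]
  final case class Mul(a: Tensor, b: Tensor) extends Op[Tensor]

  // The computation Free Monad: a program built out of Op instructions.
  type Computation[A] = Free[Op, A]

  // Smart constructors: lift a single Op into the Free Monad.
  def sum(a: Tensor, b: Tensor): Computation[Tensor] = Free.liftF[Op, Tensor](Sum(a, b))
  def sub(a: Tensor, b: Tensor): Computation[Tensor] = Free.liftF[Op, Tensor](Sub(a, b))
  def mul(a: Tensor, b: Tensor): Computation[Tensor] = Free.liftF[Op, Tensor](Mul(a, b))

  // Composing lifted operations only builds a data structure; nothing is computed
  // until an interpreter is supplied via foldMap.
  def squareOfSum(a: Tensor, b: Tensor): Computation[Tensor] =
    sum(a, b).flatMap(s => mul(s, s))
}
```

Because a program is just data, the same `squareOfSum` value can later be handed to a forward evaluator, a numerical differentiator or a back-propagating interpreter without modification.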

It is now possible to define a computation in terms of the computation free monad using Scala's for-comprehension syntax.

The function takes a map where each key is a variable or constant name and the value is the variable or constant itself. The expression accepts three variables (x1, x2 and x3) and one constant (c1).
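Here is a sketch of such an expression function for ((x1 - x2) + x3) * c1 * ((x1 - x2) + x3) * c1. It repeats a small hypothetical domain model so that it compiles on its own, again assuming cats 2.x with `cats-free`. It also adds a minimal `evaluate` interpreter into `cats.Id` that names each intermediate tensor after the sub-expression it computes. `expression` and `evaluate` are illustrative names, and splitting the expression into these four steps is an assumption.

```scala
import cats.free.Free
import cats.{~>, Id}

object Expression {
  // Hypothetical domain model and computation Free Monad (scalars only).
  sealed trait Tensor { def name: String; def value: Double }
  final case class Variable(name: String, value: Double) extends Tensor
  final case class Constant(name: String, value: Double) extends Tensor

  sealed trait Op[A]
  final case class Sum(a: Tensor, b: Tensor) extends Op[Tensor]
  final case class Sub(a: Tensor, b: Tensor) extends Op[Tensor]
  final case class Mul(a: Tensor, b: Tensor) extends Op[Tensor]

  type Computation[A] = Free[Op, A]
  def sum(a: Tensor, b: Tensor): Computation[Tensor] = Free.liftF[Op, Tensor](Sum(a, b))
  def sub(a: Tensor, b: Tensor): Computation[Tensor] = Free.liftF[Op, Tensor](Sub(a, b))
  def mul(a: Tensor, b: Tensor): Computation[Tensor] = Free.liftF[Op, Tensor](Mul(a, b))

  // ((x1 - x2) + x3) * c1 * ((x1 - x2) + x3) * c1, written as a for-comprehension.
  // The shared sub-expression is computed once and then multiplied by itself.
  def expression(in: Map[String, Tensor]): Computation[Tensor] =
    for {
      diff    <- sub(in("x1"), in("x2"))
      sum3    <- sum(diff, in("x3"))
      scaled  <- mul(sum3, in("c1"))
      squared <- mul(scaled, scaled)
    } yield squared

  // Minimal forward interpreter: computes values and names each intermediate
  // tensor after the sub-expression that produced it.
  val evaluate: Op ~> Id = new (Op ~> Id) {
    def apply[A](op: Op[A]): Id[A] = op match {
      case Sum(a, b) => Variable(s"(${a.name} + ${b.name})", a.value + b.value)
      case Sub(a, b) => Variable(s"(${a.name} - ${b.name})", a.value - b.value)
      case Mul(a, b) => Variable(s"(${a.name} * ${b.name})", a.value * b.value)
    }
  }
}
```

For example, with x1 = 1, x2 = 2, x3 = 3 and c1 = 0.5 (arbitrary values chosen for illustration), `expression(bindings).foldMap(evaluate)` returns a tensor with value 1.0 named `((((x1 - x2) + x3) * c1) * (((x1 - x2) + x3) * c1))`.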

Here is the computation written out:

((x1 - x2) + x3) * c1 * ((x1 - x2) + x3) * c1
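Because the expression is the square of c1 (x1 - x2 + x3), its partial derivatives with respect to the three variables can be worked out by hand. This gives reference values to check the interpreters against:

```latex
\begin{aligned}
f(x_1, x_2, x_3) &= \bigl(c_1\,(x_1 - x_2 + x_3)\bigr)^2 \\
\frac{\partial f}{\partial x_1} &= 2c_1^2\,(x_1 - x_2 + x_3) \\
\frac{\partial f}{\partial x_2} &= -2c_1^2\,(x_1 - x_2 + x_3) \\
\frac{\partial f}{\partial x_3} &= 2c_1^2\,(x_1 - x_2 + x_3)
\end{aligned}
```

For instance, at the illustrative point x1 = 1, x2 = 2, x3 = 3 with c1 = 0.5, this gives f = 1 and a gradient of (1, -1, 1).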

Numerical Gradient Interpreter

Numerical gradient is used here in the finite-difference sense described on Wikipedia. When a computational expression depends on multiple input variables, a partial derivative can be computed for each of them. The easiest way to compute such a partial derivative is to feed the initial set of values into the computation and record the output. Then feed the same values in again, but with the variable we are differentiating with respect to increased by a small delta. Subtracting the first output from the second and dividing by the delta gives an approximation of the partial derivative. This follows directly from the definition of the derivative as the limit of that ratio as the delta goes to zero.
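In symbols, this is the forward-difference approximation, where e_i is the unit vector along x_i and δ is the small delta:

```latex
\frac{\partial f}{\partial x_i}(\mathbf{x}) \approx \frac{f(\mathbf{x} + \delta\,\mathbf{e}_i) - f(\mathbf{x})}{\delta}
```

The error of this approximation is proportional to δ. For the squared expression above it is exactly c1²δ for each variable. For example, with s = x1 - x2 + x3, f(x + δe_1) - f(x) = c1²(2sδ + δ²), and dividing by δ leaves the exact derivative plus c1²δ.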

We can immediately try it for a simple expression:
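A minimal sketch of a numerical gradient on top of the computation Free Monad follows. It assumes a small hypothetical domain model, repeated here so the snippet compiles on its own with cats 2.x (`cats-core` and `cats-free`). `evaluate` interprets a program into `cats.Id`, and `numericalGradient` (a hypothetical name) nudges one variable at a time by `delta`. As a small deviation from running the interpretation twice per variable, it evaluates the unperturbed baseline only once. Constants are skipped.

```scala
import cats.free.Free
import cats.{~>, Id}

object NumericalGradient {
  // Hypothetical domain model and computation Free Monad (scalars only).
  sealed trait Tensor { def name: String; def value: Double }
  final case class Variable(name: String, value: Double) extends Tensor
  final case class Constant(name: String, value: Double) extends Tensor

  sealed trait Op[A]
  final case class Sum(a: Tensor, b: Tensor) extends Op[Tensor]
  final case class Sub(a: Tensor, b: Tensor) extends Op[Tensor]
  final case class Mul(a: Tensor, b: Tensor) extends Op[Tensor]

  type Computation[A] = Free[Op, A]
  def sum(a: Tensor, b: Tensor): Computation[Tensor] = Free.liftF[Op, Tensor](Sum(a, b))
  def sub(a: Tensor, b: Tensor): Computation[Tensor] = Free.liftF[Op, Tensor](Sub(a, b))
  def mul(a: Tensor, b: Tensor): Computation[Tensor] = Free.liftF[Op, Tensor](Mul(a, b))

  // ((x1 - x2) + x3) * c1 * ((x1 - x2) + x3) * c1
  def expression(in: Map[String, Tensor]): Computation[Tensor] =
    for {
      diff    <- sub(in("x1"), in("x2"))
      sum3    <- sum(diff, in("x3"))
      scaled  <- mul(sum3, in("c1"))
      squared <- mul(scaled, scaled)
    } yield squared

  // Forward interpreter: plain arithmetic; intermediate names are irrelevant here.
  val evaluate: Op ~> Id = new (Op ~> Id) {
    def apply[A](op: Op[A]): Id[A] = op match {
      case Sum(a, b) => Variable("tmp", a.value + b.value)
      case Sub(a, b) => Variable("tmp", a.value - b.value)
      case Mul(a, b) => Variable("tmp", a.value * b.value)
    }
  }

  // Runs one full interpretation and returns the scalar result.
  def run(program: Computation[Tensor]): Double = {
    val out: Tensor = program.foldMap(evaluate)
    out.value
  }

  // Forward difference (f(x + delta * e_i) - f(x)) / delta for every Variable.
  def numericalGradient(
      expr: Map[String, Tensor] => Computation[Tensor],
      bindings: Map[String, Tensor],
      delta: Double = 1e-6
  ): Map[String, Double] = {
    val baseline = run(expr(bindings))
    bindings.collect { case (name, v: Variable) =>
      val nudged = bindings.updated(name, v.copy(value = v.value + delta))
      name -> ((run(expr(nudged)) - baseline) / delta)
    }
  }
}
```

With x1 = 1, x2 = 2, x3 = 3 and c1 = 0.5 (illustrative values), `numericalGradient(expression, bindings)` returns values close to 1, -1 and 1 for x1, x2 and x3, each off by roughly c1² times `delta`.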

This seems to work fine, but it performs poorly: with large input tensors, the interpretation has to be run twice for every single scalar in every tensor. It would be much better to derive the derivatives analytically first and then evaluate them all in a single run using vectorized tensor operations.

Analytical Gradient Interpreter

It turns out that we can apply the chain rule to the computational graph and compute the partial derivatives for all input variables in a single interpreter run (in applications to neural networks, this is known as back-propagation). Here is the interpreter doing this:

Its implementation is more complex, but it can compute all partial derivatives in a single interpretation run. Let's try it and see whether it matches what was calculated by the six separate numerical interpreter executions (two per input variable).
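A minimal sketch of such a back-propagating interpreter follows. It assumes a hypothetical domain model, repeated so the snippet compiles on its own with cats 2.x (`cats-core` and `cats-free`). During the forward pass, the interpreter runs in `cats.data.State` and records a tape of nodes. Each node holds the local partial derivatives of an operation's output with respect to its inputs. A reverse sweep over the tape then applies the chain rule. `Node`, `Tape`, `backprop` and `gradient` are illustrative names, and the tape-based design is one common way to implement reverse-mode differentiation, not necessarily the one in the repository.

```scala
import cats.data.State
import cats.free.Free
import cats.~>

object AnalyticalGradient {
  // Hypothetical domain model and computation Free Monad (scalars only).
  sealed trait Tensor { def name: String; def value: Double }
  final case class Variable(name: String, value: Double) extends Tensor
  final case class Constant(name: String, value: Double) extends Tensor

  sealed trait Op[A]
  final case class Sum(a: Tensor, b: Tensor) extends Op[Tensor]
  final case class Sub(a: Tensor, b: Tensor) extends Op[Tensor]
  final case class Mul(a: Tensor, b: Tensor) extends Op[Tensor]

  type Computation[A] = Free[Op, A]
  def sum(a: Tensor, b: Tensor): Computation[Tensor] = Free.liftF[Op, Tensor](Sum(a, b))
  def sub(a: Tensor, b: Tensor): Computation[Tensor] = Free.liftF[Op, Tensor](Sub(a, b))
  def mul(a: Tensor, b: Tensor): Computation[Tensor] = Free.liftF[Op, Tensor](Mul(a, b))

  // ((x1 - x2) + x3) * c1 * ((x1 - x2) + x3) * c1
  def expression(in: Map[String, Tensor]): Computation[Tensor] =
    for {
      diff    <- sub(in("x1"), in("x2"))
      sum3    <- sum(diff, in("x3"))
      scaled  <- mul(sum3, in("c1"))
      squared <- mul(scaled, scaled)
    } yield squared

  // A recorded vertex: its output name and the local partial derivative of that
  // output with respect to each input tensor (by name).
  final case class Node(out: String, inputs: List[(String, Double)])
  // The tape: a counter for fresh intermediate names plus the nodes, newest first.
  final case class Tape(nextId: Int, nodes: List[Node])
  type Taped[A] = State[Tape, A]

  // Forward step: compute the value, give it a fresh name, record local derivatives.
  private def record(value: Double, inputs: List[(Tensor, Double)]): Taped[Tensor] =
    State { tape =>
      val out: Tensor = Variable(s"#${tape.nextId}", value)
      val node = Node(out.name, inputs.map { case (t, d) => (t.name, d) })
      (Tape(tape.nextId + 1, node :: tape.nodes), out)
    }

  val backprop: Op ~> Taped = new (Op ~> Taped) {
    def apply[A](op: Op[A]): Taped[A] = op match {
      case Sum(a, b) => record(a.value + b.value, List(a -> 1.0, b -> 1.0))
      case Sub(a, b) => record(a.value - b.value, List(a -> 1.0, b -> -1.0))
      case Mul(a, b) => record(a.value * b.value, List(a -> b.value, b -> a.value))
    }
  }

  // Returns the expression value and the partial derivative for every Variable.
  def gradient(program: Computation[Tensor], bindings: Map[String, Tensor]): (Double, Map[String, Double]) = {
    val (tape, out) = program.foldMap(backprop).run(Tape(0, Nil)).value
    // Reverse sweep (chain rule): nodes are newest first, so every consumer of a
    // node has already added its contribution before that node is propagated.
    val adjoints = tape.nodes.foldLeft(Map(out.name -> 1.0)) { (adj, node) =>
      val upstream = adj.getOrElse(node.out, 0.0)
      node.inputs.foldLeft(adj) { case (acc, (input, local)) =>
        acc.updated(input, acc.getOrElse(input, 0.0) + upstream * local)
      }
    }
    val grads = bindings.collect { case (name, v: Variable) => name -> adjoints.getOrElse(v.name, 0.0) }
    (out.value, grads)
  }
}
```

On the illustrative bindings x1 = 1, x2 = 2, x3 = 3 and c1 = 0.5, `gradient(expression(bindings), bindings)` yields the value 1.0 and the gradient x1 → 1.0, x2 → -1.0, x3 → 1.0 from a single interpretation run. These values agree exactly with differentiating (c1 (x1 - x2 + x3))² by hand, with no finite-difference error. Note that `Mul(scaled, scaled)` records the same input twice, so both contributions are summed, which correctly yields the factor of 2 from squaring.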

Not only does the result match the one obtained with the numerical gradient interpreter, but the precision is much better, and it was calculated in a single interpreter run.

Gradient Descent Optimizer

With gradient-calculating interpreters at our disposal, we can easily create an optimizer that uses gradient descent to minimize a function's value. This is very useful for solving equations or for training a machine learning model against a given cost function.

Let’s try it right away on the same expression and see whether it works.
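Here is a minimal sketch of such an optimizer. It repeats a small hypothetical model so the snippet compiles on its own with cats 2.x (`cats-core` and `cats-free`). To keep it short, `minimize` (a hypothetical name) plugs in a forward-difference gradient; a back-propagating gradient function would slot in the same way. The learning rate of 0.1 and the 200 steps are arbitrary illustrative defaults.

```scala
import cats.free.Free
import cats.{~>, Id}

object GradientDescent {
  // Hypothetical domain model and computation Free Monad (scalars only).
  sealed trait Tensor { def name: String; def value: Double }
  final case class Variable(name: String, value: Double) extends Tensor
  final case class Constant(name: String, value: Double) extends Tensor

  sealed trait Op[A]
  final case class Sum(a: Tensor, b: Tensor) extends Op[Tensor]
  final case class Sub(a: Tensor, b: Tensor) extends Op[Tensor]
  final case class Mul(a: Tensor, b: Tensor) extends Op[Tensor]

  type Computation[A] = Free[Op, A]
  def sum(a: Tensor, b: Tensor): Computation[Tensor] = Free.liftF[Op, Tensor](Sum(a, b))
  def sub(a: Tensor, b: Tensor): Computation[Tensor] = Free.liftF[Op, Tensor](Sub(a, b))
  def mul(a: Tensor, b: Tensor): Computation[Tensor] = Free.liftF[Op, Tensor](Mul(a, b))

  // ((x1 - x2) + x3) * c1 * ((x1 - x2) + x3) * c1
  def expression(in: Map[String, Tensor]): Computation[Tensor] =
    for {
      diff    <- sub(in("x1"), in("x2"))
      sum3    <- sum(diff, in("x3"))
      scaled  <- mul(sum3, in("c1"))
      squared <- mul(scaled, scaled)
    } yield squared

  // Forward interpreter into cats.Id.
  val evaluate: Op ~> Id = new (Op ~> Id) {
    def apply[A](op: Op[A]): Id[A] = op match {
      case Sum(a, b) => Variable("tmp", a.value + b.value)
      case Sub(a, b) => Variable("tmp", a.value - b.value)
      case Mul(a, b) => Variable("tmp", a.value * b.value)
    }
  }

  def run(program: Computation[Tensor]): Double = {
    val out: Tensor = program.foldMap(evaluate)
    out.value
  }

  // Forward-difference gradient for every Variable; constants are skipped.
  def numericalGradient(
      expr: Map[String, Tensor] => Computation[Tensor],
      bindings: Map[String, Tensor],
      delta: Double = 1e-6
  ): Map[String, Double] = {
    val baseline = run(expr(bindings))
    bindings.collect { case (name, v: Variable) =>
      val nudged = bindings.updated(name, v.copy(value = v.value + delta))
      name -> ((run(expr(nudged)) - baseline) / delta)
    }
  }

  // Plain gradient descent: move every Variable against its partial derivative,
  // leaving constants untouched, for a fixed number of steps.
  def minimize(
      expr: Map[String, Tensor] => Computation[Tensor],
      initial: Map[String, Tensor],
      learningRate: Double = 0.1,
      steps: Int = 200
  ): Map[String, Tensor] =
    (1 to steps).foldLeft(initial) { (bindings, _) =>
      val grads = numericalGradient(expr, bindings)
      bindings.map { case (name, t) =>
        val updated: Tensor = t match {
          case v: Variable => v.copy(value = v.value - learningRate * grads(name))
          case constant    => constant
        }
        name -> updated
      }
    }
}
```

Consider the start x1 = 1, x2 = 2, x3 = 3 with c1 = 0.5. There, each step moves all three variables by about 0.05·(x1 - x2 + x3) in the direction that decreases that sum, so the sum shrinks by a factor of about 0.85 per step. After 200 steps, the expression value is far below 1e-9, while c1 stays at 0.5.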

The optimizer found a set of input variable values for which the expression is almost zero. This is the minimum possible value, since the expression is a square and therefore non-negative by definition.

Conclusion

I hope you also find it fun to play with partial derivatives and gradient descent expressed in a functional way using free monads. I used to implement the same approach in Python, but that implementation was more cumbersome.