Metaprogramming magic with Scalameta

In this article, I will show you how to do metaprogramming in Scala. We will implement a memoization annotation that rewrites a function by adding a caching mechanism inside it. Our tools will be Scalameta and the Macro Paradise plugin. But let’s start with a definition of metaprogramming.

Metaprogramming is a programming technique in which computer programs have the ability to treat programs as their data.

The most boring thing I can imagine is talking about abstract stuff, so let’s find a real-life case! If you write functional code, you probably have a lot of pure functions, whose output depends only on the parameters they accept. So why should we evaluate those functions every time if we can just cache the results for each set of parameters? And we probably want a generic mechanism that rewrites our functions for us. Looks like a nice task for Scalameta!

All great things start small

First of all, we must set up a project to experiment with macros. It will be a default sbt project with the Scalameta and Macro Paradise dependencies.

Scalameta is a metaprogramming toolkit for Scala that lets you parse source code into a syntax tree (AST) and work with that tree. What does that mean for us? It means that we can compose code as strings and then splice it into real programs! It’s like eval, but with type checks at compile time. Let’s add it to build.sbt.

libraryDependencies += "org.scalameta" %% "scalameta" % "1.7.0"

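Macro Paradise is a compiler plugin that makes the latest macro developments available long before they end up in future versions of Scala. Let’s add it to build.sbt as well. Note that the artifact name used here is tied to Scala 2.12.2, so the scalaVersion of the project should match.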

resolvers += Resolver.sonatypeRepo("releases")
addCompilerPlugin("org.scalameta" % "paradise_2.12.2" % "3.0.0-M8")

A very important fact: we cannot use a macro in the same compilation scope where it is defined. So, to play with the code that we will write, let’s add ScalaTest and use the macro from tests.

libraryDependencies += "org.scalatest" %% "scalatest" % "3.0.1" 
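If you would rather call the macro from regular code than from tests, one common way to get a separate compilation scope is a two-module sbt build. The sketch below is only an assumption about how such a build could look: the module names macros and app and the shared metaSettings value are hypothetical, while the dependency lines are the ones from this article.

```scala
// build.sbt (sketch): the module names "macros" and "app" are hypothetical
lazy val metaSettings: Seq[Def.Setting[_]] = Seq(
  scalaVersion := "2.12.2",
  resolvers += Resolver.sonatypeRepo("releases"),
  libraryDependencies += "org.scalameta" %% "scalameta" % "1.7.0",
  addCompilerPlugin("org.scalameta" % "paradise_2.12.2" % "3.0.0-M8")
)

// Holds the annotation class
lazy val macros = project.settings(metaSettings)

// Uses the annotation; compiled after "macros" because of dependsOn
lazy val app = project
  .dependsOn(macros)
  .settings(metaSettings)
```

In this article we stay with a single project and tests, which is enough to see the macro in action.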

So what exactly will we do?

Our plan is to implement an annotation that injects caching behavior into a function. The steps are:

  1. Add a dummy annotation that will just handle functions
  2. Implement test cases for our macro
  3. Learn how to inject the terms that we want to have inside a function
  4. Parse the function parameters and calculate a hash
  5. Inject caching behavior that works on top of the hashing result.

First step: just a dummy annotation

To start working on our fancy annotation, let’s create a Memoize class with the standard boilerplate for an annotation.

import scala.annotation._
import scala.collection.immutable // needed later for immutable.Seq
import scala.meta._

@compileTimeOnly("Should be used only in compile time.")
class Memoize extends StaticAnnotation {
  inline def apply(defn: Any): Any = meta {
    defn match {
      case _ => abort("@Memoize must annotate a function.")
    }
  }
}

What’s going on here? First of all, we import the annotation and metaprogramming packages. Then we declare that the class can be used at compile time only. We mark our class as a StaticAnnotation and implement the apply method, which will contain the behavior of our annotation. You can also see a strange defn parameter: it is exactly the annotated code that we will operate on later.

Let’s play with our macros

To see how it works, we should create a test. It is compiled in a separate compilation scope, so everything will work fine. Let’s emulate different use cases for our macro.

import org.scalatest.{Matchers, WordSpec}

class MacrosTest extends WordSpec with Matchers {

  "Function after macro transformation" should {

    "cache function result" in new Context {
      testFunction("test", 1, 1)
      testFunction("test", 1, 1)
      testFunction("test", 1, 1)
      testFunction("test", 1, 1)

      callsCount shouldBe 1
    }

    "cache function result depends on parameters" in new Context {
      testFunction("test", 1, 1)
      testFunction("test1", 1, 1)
      testFunction("test", 2, 1)
      testFunction("test", 1, 3)

      callsCount shouldBe 4
    }
  }

  trait Context {
    var callsCount = 0

    @Memoize
    def testFunction(a: String, b: Int, c: Long) = {
      callsCount += 1
      s"$a - $b - $c"
    }
  }
}

As you can see in the Context trait, we used our annotation on the test function. If the macro is implemented properly, the function body runs, and the call counter grows, only once for each distinct set of arguments. To start working on the macro, open the sbt console and run ~;clean;test. For now this test will not compile, because our macro aborts on any input.

Play hard: manipulations with syntax tree

Now we can start working on the annotation itself. The first thing we need to write is a case in our match expression. The only kind of definition we want to handle is def, so let’s just pattern match on it!

defn match {
  case q"def $name(...$paramss): $result = { ..$body }" =>
    ???
  case _ =>
    abort("@Memoize must annotate a function.")
}

At first glance it looks weird, right? What is that q before the string, and what are all those symbols inside it? The string interpolator q is a quasiquote. A quasiquote is a notation that helps you manipulate Scala syntax trees. It’s a very powerful tool implemented by Denys Shabalin from EPFL. With its help, we can process syntax trees as if they were regular strings. Cool, isn’t it?!

So in our quasiquote we are trying to match code that starts with def and has a name, parameters, a result type and a function body. We literally pattern match a syntax tree against a string. Now we have variables holding the parts that we extracted from the function definition.

Also, you can see weird stuff like .. and ... before the extracted variables. Your IDE will probably not highlight it, but it is an important part of the syntax. Two dots mean that the matched part is a list of trees, and three dots mean that it is a list of lists of trees.

A term in Scalameta is a syntax tree node for an expression, such as a name, a literal, a method call or a block. In our pattern, the parameters are Term.Param nodes and the function body is a list of statements.

If you run it now, compilation will fail, so we should fix it. We have all the data that we need to rebuild our function, but for a start, it would be nice to rebuild the function exactly as it was.

defn match {
  case q"def $name(...$paramss): $result = { ..$body }" =>
    q"def $name(...$paramss): $result = { ..$body }"
  case _ =>
    abort("@Memoize must annotate a function.")
}

Here we did the reverse operation: we use the extracted parts to build a new function definition, again with the help of quasiquotes. If you run the tests now, they should compile and execute, but the caching test will still fail, because nothing is cached yet.

Just as an example, let’s inject a call to println("Macro hello") before everything else in the function body.

defn match {
  case q"def $name(...$paramss): $result = { ..$body }" =>
    val newBody = immutable.Seq(
      q"println(${"Macro hello"})"
    ) ++ body
    q"def $name(...$paramss): $result = { ..$newBody }"
  case _ =>
    abort("@Memoize must annotate a function.")
}

Here we built a term for the println call and put it before all the other statements of the function body. If you rerun your tests, you should see a lot of messages in the console. So we did it! We injected our code at compile time.
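To make the effect concrete, here is a hand-written sketch of what the test function should roughly look like after this version of the macro has run. It is plain Scala with no macro involved, and the wrapping object ExpandedPrintlnSketch is only a hypothetical container so that the snippet compiles on its own.

```scala
object ExpandedPrintlnSketch {
  var callsCount = 0

  // Hand-written equivalent of the expected expansion:
  // the injected println comes first, then the original body unchanged.
  def testFunction(a: String, b: Int, c: Long): String = {
    println("Macro hello")
    callsCount += 1
    s"$a - $b - $c"
  }
}
```

Every call still runs the whole body, which is why the caching test keeps failing at this stage.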

More complicated things, working with function parameters

To be sure that we get the correct value from the cache, we must build some kind of hash from the parameter list. To do that, we must generate code that concatenates the hash codes of all parameters. As you can see, $paramss in our pattern is preceded by ... , which means that $paramss is a list of lists of parameters. That is because a function definition can have more than one parameter list. Let’s write the code that will generate the hash expression for us.

def hashEvaluationTerm(paramss: Seq[Seq[Term.Param]]) =
  paramss
    .flatten
    .map(_.name.value)
    .map(_ + ".hashCode().toString()")
    .mkString(" + ")
    .parse[Term].get

In this example, we take the names of all parameters, append the hashCode() and toString() calls to each name, and then join everything with +. At that point it is still just a String that looks like Scala code. Now it’s time for magic: with Scalameta we can parse it as a term. After that, we have a proper term whose result will be the key for our cache storage. To take a look at the result, let’s inject this term into our example.
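If you want to see the intermediate string before it is parsed, the same chain can be tried without Scalameta at all. In this sketch the parameters are represented by their names only, which is a simplification of Seq[Seq[Term.Param]], and the names HashExpressionSketch and hashExpressionSource are hypothetical.

```scala
object HashExpressionSketch {
  // Hypothetical stand-in for Seq[Seq[Term.Param]]: only the parameter names,
  // one inner Seq per parameter list of the annotated function.
  def hashExpressionSource(paramNames: Seq[Seq[String]]): String =
    paramNames
      .flatten
      .map(_ + ".hashCode().toString()")
      .mkString(" + ")

  def main(args: Array[String]): Unit = {
    // Corresponds to: def testFunction(a: String, b: Int)(c: Long)
    println(hashExpressionSource(Seq(Seq("a", "b"), Seq("c"))))
    // prints: a.hashCode().toString() + b.hashCode().toString() + c.hashCode().toString()
  }
}
```

Note that flatten erases the boundaries between parameter lists, so a curried function and an uncurried one produce the same expression.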

defn match {
  case q"def $name(...$paramss): $result = { ..$body }" =>
    val newBody = immutable.Seq(
      q"println(${hashEvaluationTerm(paramss)})"
    ) ++ body
    q"def $name(...$paramss): $result = { ..$newBody }"
  case _ =>
    abort("@Memoize must annotate a function.")
}

As a result, after clean and test, you should see the hash value printed for the parameters that we pass to our test function.
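Looking at the printed keys also makes one limitation of this scheme easy to spot: the hash codes are glued together without a separator, so different arguments can produce the same key. The sketch below reproduces the generated expression by hand; tupleKey is just one possible alternative and is my assumption, not something the macro in this article does.

```scala
object CacheKeySketch {
  // Same scheme as the generated code: concatenated hash codes.
  def concatenatedKey(a: String, b: Int, c: Long): String =
    a.hashCode().toString() + b.hashCode().toString() + c.hashCode().toString()

  // Hypothetical alternative: keep the argument values themselves as the key.
  def tupleKey(a: String, b: Int, c: Long): (String, Int, Long) = (a, b, c)

  def main(args: Array[String]): Unit = {
    // "1" + "11" and "11" + "1" are the same string, so these two keys collide
    println(concatenatedKey("test", 1, 11L) == concatenatedKey("test", 11, 1L))
    println(tupleKey("test", 1, 11L) == tupleKey("test", 11, 1L))
  }
}
```

For the purposes of this article the simple concatenation is good enough, but it is one of the reasons not to treat the result as production-ready.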

Put all things in the… cache

At this point of our adventure, we already know everything about macros that we need. Only the implementation of caching is left. To make it clear, let’s define what we want our macro to output.

val testFunction = {
  val $cache = new ConcurrentHashMap[String, Any]
  (a: String, b: Int, c: Long) => {
    val $hash = a.hashCode().toString() + b.hashCode().toString() + c.hashCode().toString()
    if ($cache.containsKey($hash)) {
      $cache.get($hash)
    } else {
      val $result = {
        callsCount += 1
        s"$a - $b - $c"
      }
      $cache.put($hash, $result)
      $result
    }
  }
}

In this example, you can see that we don’t use def anymore. Instead, we have a val that holds an anonymous function with the same parameter list. This is needed so that the cache can live as inner state next to our anonymous function. Inside the function, we evaluate the hash with the help of the term that we defined in the previous step. With that hash we can decide whether to use the cached result or to evaluate the function from scratch. Not many things are left: we just need to write all this with quasiquotes. Before that, we will define some constants and a helper function.
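To try the target shape without any macro involved, here is a self-contained, hand-written version of it. The object MemoizedShapeSketch, the plain names cache, hash and result, and the explicit type annotation on the val are my additions for the sketch. The annotation makes one detail visible: because the cache stores Any, the function result is typed as Any too.

```scala
import java.util.concurrent.ConcurrentHashMap

object MemoizedShapeSketch {
  var callsCount = 0

  // A val holding a function: the cache is created once and captured by the closure
  val testFunction: (String, Int, Long) => Any = {
    val cache = new ConcurrentHashMap[String, Any]
    (a: String, b: Int, c: Long) => {
      val hash = a.hashCode().toString() + b.hashCode().toString() + c.hashCode().toString()
      if (cache.containsKey(hash)) {
        cache.get(hash)
      } else {
        // The original function body is evaluated only on a cache miss
        val result = {
          callsCount += 1
          s"$a - $b - $c"
        }
        cache.put(hash, result)
        result
      }
    }
  }
}
```

Calling MemoizedShapeSketch.testFunction("test", 1, 1) several times in a row should leave callsCount at 1, which is exactly what our first test expects from the macro.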

val cacheVar = q"$$cache"
val hashVar = q"$$hash"
val resultVar = q"$$result"

def patName(termName: Term.Name): Pat.Var.Term =
  Pat.Var.Term(termName)

We need the first three constants to control the names of the generated variables and, possibly, to check later that they are not already used in the function definition. The patName function is just a helper that converts a term name into the pattern type required on the left-hand side of a val. Finally, our match expression should look like this.

defn match {
  case q"def $name(...$paramss): $result = { ..$body }" =>

    // Term that checks whether the cache contains an evaluated result
    val isCacheContainsResult = q"$cacheVar.containsKey($hashVar)"

    // Term that gets the result from the cache
    val getResultFromCache = q"$cacheVar.get($hashVar)"

    // Terms that evaluate the result and put it in the cache
    val evaluateFunction = immutable.Seq(
      q"val ${patName(resultVar)} = { ..$body }",
      q"$cacheVar.put($hashVar, $resultVar)",
      q"$resultVar"
    )

    // Terms that evaluate the hash and get the result with it
    val evaluateHashAndGetResult = immutable.Seq(
      q"val ${patName(hashVar)} = ${hashEvaluationTerm(paramss)}",
      q"if($isCacheContainsResult) { $getResultFromCache } else { ..$evaluateFunction }"
    )

    // Terms that build the cache and the anonymous function returning the result
    val initializeCacheAndBuildFunc = immutable.Seq(
      q"val ${patName(cacheVar)} = new java.util.concurrent.ConcurrentHashMap[String, Any]",
      q"((..${paramss.flatten}) => { ..$evaluateHashAndGetResult })"
    )

    q"val ${patName(name)} = { ..$initializeCacheAndBuildFunc }"

  case _ =>
    abort("@Memoize must annotate a function.")
}

In the example above, we just compose the different parts of the syntax tree that we prepared with quasiquotes. Finally, we return a val with the same name as before. This val holds a function with the same parameters as our original function (its result is typed as Any, because the cache stores Any). At this step all our tests will pass, which means that everything is done!

Summary

In this article I described the basic operations that we can do with macros, so now you hopefully understand the basic principles of metaprogramming. The annotation that we created works, but still, please don’t use it in production systems ;). If you start working on your own macros, you will probably be interested in the Scalameta tutorial and the Quasiquotes documentation.

I hope this article helped you understand the basics of working with macros. If you have any questions, feel free to ask in the comments.