Type inference for beginners — Part 1

Dhruv Rajvanshi
12 min read · Oct 20, 2017


Greetings. This is a tutorial series for gradually implementing type checking/inference.

*Also see the EDIT notes at the end.

We’ll start with a very simple language (the simply typed lambda calculus) and gradually add features to it. By the end, we’ll have a compiler with the following features:

  • Proper error reporting
  • Parametric polymorphism (also called generics)
  • Records/Structs
  • Methods
  • Tuples (also called product types)
  • Algebraic data types (also called tagged union/sum types)
  • Pattern matching
  • Assignments
  • Blocks
  • Modules
  • Interfaces/typeclasses
  • Sub-typing
  • Higher kinded types
  • Rank-N types

If you don’t know what some (or any) of these things are, don’t worry: I’ll assume no knowledge of type theory and will explain things as we go along.

Also, to make it accessible to most people, I’ll be using TypeScript as the implementation language.

If you don’t know TypeScript, don’t worry: it’s basically JavaScript with type annotations. You can ignore the type annotations in the code and pretend it’s JavaScript.

With that out of the way, let’s get started, shall we?

In this part

We’ll implement type inference and checking for a very simple language called lambda calculus. The entire type checker is fewer than 200 lines, so it’s easy to digest.

Syntax

The only syntactic category in lambda calculus is the expression. An expression is the kind of thing that can appear on the right-hand side of an assignment statement in most languages. I’ve changed the syntax of the original lambda calculus slightly so that it looks familiar to beginners. If you already know lambda calculus, you may skim this section.

Here’s an informal description of the syntax. An expression can be one of the following things:

  • An integer literal (e.g. 0, 1, …)
  • A variable (e.g. x, y, abc, etc)
  • A function (e.g. (x) => increment(x)), as in JavaScript but with exactly one parameter
  • A call (e.g. f(1)), also with exactly one argument
  • An if expression (e.g. if (isEven(x)) 1 else 2)

There are a few things to note here.

Firstly, every function takes exactly one argument and returns one value. For functions with multiple arguments, we use a technique called “currying”. Basically, you write a function with multiple parameters like this:

  • (x) => ((y) => (add(x))(y))

The function arrow associates to the right and calls associate to the left, so you can write it more concisely as

  • (x) => (y) => add(x)(y)

If you call this function with one argument (x), it returns another function that takes one argument (y), which calls add with x and then calls the result of that with y.
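For intuition, the same currying pattern can be written directly in TypeScript. This is only an analogy in the implementation language, not part of the type checker, and the names add and increment are illustrative:

```typescript
// Curried addition: each arrow function takes exactly one argument.
const add = (x: number) => (y: number): number => x + y;

// Calling with only x returns a function that is still waiting for y.
const increment = add(1);

// Calls associate to the left: add(2)(3) means (add(2))(3).
const five = add(2)(3);
const alsoFive = increment(4);
```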

Secondly, “if” is an expression, rather than a statement as in most languages. This means it “returns” a value when evaluated, so you can write something like

  • (condition) => if (condition) 1 else 2

When called with true, this function returns 1; otherwise, it returns 2.

Syntax Tree

What?

You may skip this subsection if you know what syntax trees are.

Basically, any valid program that the user types into a file is represented as a “tree” of “nodes”. A tree is a structure made of nodes that contain information and can optionally have child nodes. A parser takes the source text and converts it into a tree structure that the compiler can easily manipulate. We won’t implement a parser right now, but we will describe the tree nodes for each kind of expression.

It will become clearer when we see the code.

Expressions

Here’s the type representing an expression tree.
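A minimal sketch of such a type. The fields value, name, param and body follow the description in the text; func, arg, cond, trueBranch and falseBranch are assumed names for the remaining children:

```typescript
// One variant per kind of expression, distinguished by the nodeType tag.
type Expression =
  | { nodeType: "Int"; value: number }                  // 0, 1, ...
  | { nodeType: "Var"; name: string }                   // x, y, abc, ...
  | { nodeType: "Function"; param: string; body: Expression } // (param) => body
  | { nodeType: "Call"; func: Expression; arg: Expression }   // func(arg)
  | { nodeType: "If"; cond: Expression; trueBranch: Expression; falseBranch: Expression };
```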

As you can see, each expression node has a “nodeType” property, which indicates what kind of expression the node is, plus some values associated with that kind. If you don’t know TypeScript, the ‘|’ operator means the type Expression can have any one of these structures. An expression with nodeType “Int” has a value property containing the number it represents. A “Var” node contains the variable name. The rest should be self-explanatory.

So, for example, here’s an expression and its tree object
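As an illustration, here is a possible tree for the curried expression (x) => (y) => add(x)(y) from the syntax section. The expression type is repeated so the snippet stands alone, and its child field names (func, arg, cond, trueBranch, falseBranch) are assumptions:

```typescript
type Expression =
  | { nodeType: "Int"; value: number }
  | { nodeType: "Var"; name: string }
  | { nodeType: "Function"; param: string; body: Expression }
  | { nodeType: "Call"; func: Expression; arg: Expression }
  | { nodeType: "If"; cond: Expression; trueBranch: Expression; falseBranch: Expression };

// Source text: (x) => (y) => add(x)(y)
const curriedAdd: Expression = {
  nodeType: "Function",
  param: "x",
  body: {
    nodeType: "Function",
    param: "y",
    body: {
      // Calls associate to the left, so add(x)(y) is a call whose
      // function part is itself the call add(x).
      nodeType: "Call",
      func: {
        nodeType: "Call",
        func: { nodeType: "Var", name: "add" },
        arg: { nodeType: "Var", name: "x" },
      },
      arg: { nodeType: "Var", name: "y" },
    },
  },
};
```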

Thankfully, we don’t have to write these trees by hand in practice, because parsers convert our source text into them. Real compilers store a bit more information with each node, like its position in the original source, so that when an error occurs, the user can be told exactly where it occurred and perhaps be shown the offending snippet of code as well. We’ll add that later.

Types

Apart from expressions, we also need to represent types as trees. Each expression has a type that is known at compile time. Since there are no type annotations (yet) in our source language, the type checker that we’re writing has to figure out the type of each expression (and of its subexpressions, recursively). Types can be simple types (like Int, Bool, etc.) or function types (for example, Int -> Int is the type of a function that takes an Int and returns an Int). Because function types contain other types, the type tree is recursive.

“Named” types are types that have been declared. They include built-in types like Int and Bool, and in the future, they’ll also represent user-declared types like record types.

“Var” types are type variables. They can stand in for any type. They can be “instantiated” on use. Right now, we’ll use them to represent types that haven’t been inferred yet.

The “Function” type is self-explanatory: it’s the type of a function that takes a value of type “from” and returns a value of type “to”.
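A minimal sketch of the type tree with the three kinds of type described above. The printer typeToString is a hypothetical helper added only to make types readable; it is not part of the type checker itself:

```typescript
// Types are trees too. "Function" nodes contain other types, so the definition is recursive.
type Type =
  | { nodeType: "Named"; name: string }             // declared types: Int, Bool, ...
  | { nodeType: "Var"; name: string }               // type variables: T1, T2, ...
  | { nodeType: "Function"; from: Type; to: Type }; // from -> to

// Example: the type Int -> Bool.
const intToBool: Type = {
  nodeType: "Function",
  from: { nodeType: "Named", name: "Int" },
  to: { nodeType: "Named", name: "Bool" },
};

// Render a type as text. The arrow associates to the right, so only a
// function type in the "from" position needs parentheses.
function typeToString(t: Type): string {
  if (t.nodeType === "Function") {
    const from = t.from.nodeType === "Function" ? `(${typeToString(t.from)})` : typeToString(t.from);
    return `${from} -> ${typeToString(t.to)}`;
  }
  return t.name; // Named and Var nodes both carry a name
}
```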

Typechecker

Now, to the exciting part!

So, we’ll implement type inference and checking using a method called Hindley-Milner. Most modern statically typed functional languages (Haskell, ML, F#) use some variation of it. Note that the main goal of HM was type inference in the presence of polymorphism. For now, we’ll ignore the polymorphism part and focus on the rest of the algorithm.

Here’s how it works. Our type checker maintains an “environment”, which is simply a mapping from variable names to their inferred types. We start with an empty environment and traverse the expression tree recursively from the top down. As we infer the types of variables, we add them to the environment.

At the core of our type checker is a function called “unify”. It takes two types and tries to match them. If they can be matched, it returns a substitution (explained below) that makes them equal; if they can’t, it reports a type error.

So, our inference function walks the expression tree and uses unification to make sure that the types of its parts fit together.

Integers

Let’s start with the most basic type inference function and gradually add stuff to it. Here’s the starting point.
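A minimal sketch of that starting point. The tree shapes are repeated so the snippet stands alone (child field names beyond those in the text are assumptions), and every other kind of node simply throws for now:

```typescript
type Expression =
  | { nodeType: "Int"; value: number }
  | { nodeType: "Var"; name: string }
  | { nodeType: "Function"; param: string; body: Expression }
  | { nodeType: "Call"; func: Expression; arg: Expression }
  | { nodeType: "If"; cond: Expression; trueBranch: Expression; falseBranch: Expression };
type Type =
  | { nodeType: "Named"; name: string }
  | { nodeType: "Var"; name: string }
  | { nodeType: "Function"; from: Type; to: Type };

// Starting point: only integer literals are handled so far.
function infer(e: Expression): Type {
  switch (e.nodeType) {
    case "Int":
      return { nodeType: "Named", name: "Int" };
    default:
      throw new Error(`Unimplemented: ${e.nodeType}`);
  }
}
```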

As you can see, when you call it with an expression of nodeType “Int”, you get a named type “Int”.

Variables

Now, think about variables. For variables, we need a type environment that keeps track of the types of the variables in scope. When a variable is encountered, we need to look up the variable name in the scope. If it’s found in the environment, we return the type from the environment, otherwise, we throw an error saying that variable is not in scope.

So, we need to modify our infer function to take an additional argument, which is the environment of variables in scope.
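A sketch of infer with the extra environment argument. Representing the environment as a plain object (Env) is an assumption; the hasOwnProperty check is a defensive choice so that inherited keys such as "toString" are not mistaken for variables in scope:

```typescript
type Expression =
  | { nodeType: "Int"; value: number }
  | { nodeType: "Var"; name: string }
  | { nodeType: "Function"; param: string; body: Expression }
  | { nodeType: "Call"; func: Expression; arg: Expression }
  | { nodeType: "If"; cond: Expression; trueBranch: Expression; falseBranch: Expression };
type Type =
  | { nodeType: "Named"; name: string }
  | { nodeType: "Var"; name: string }
  | { nodeType: "Function"; from: Type; to: Type };

// The environment maps variable names (expressions) to their types.
type Env = { [name: string]: Type };

const hasOwn = (obj: object, key: string): boolean =>
  Object.prototype.hasOwnProperty.call(obj, key);

function infer(env: Env, e: Expression): Type {
  switch (e.nodeType) {
    case "Int":
      return { nodeType: "Named", name: "Int" };
    case "Var":
      // Look the variable up; anything missing is not in scope.
      if (!hasOwn(env, e.name)) {
        throw new Error(`${e.name} is not in scope`);
      }
      return env[e.name];
    default:
      throw new Error(`Unimplemented: ${e.nodeType}`);
  }
}
```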

We just added a case for the “Var” nodeType. This should be pretty straightforward.

Functions

Now, we’ll look at something a bit more involved: functions. When we see a function, we need to infer the type of its body, so we make a recursive call to infer with the function’s body. The body may refer to the function’s parameter, so the parameter needs to be added to the environment. We don’t have to add it permanently, only for that call, so we make a copy of the environment with one extra member, the parameter. But here we run into a problem. Each environment entry needs a name and a type. We have the name from the param property of the expression node, but we don’t know its type. This is where type variables come in. We simply generate a new, unique type variable to stand in for the parameter’s type. Then we can add <param_name, new_type_variable> to the environment and pass the new environment and the function’s body to the infer function.

Now, during the call to infer for the function’s body, we may discover the actual type that the parameter’s type variable stands for. In general, the infer function may discover the actual types that any of the type variables in the environment stand for. Here’s an example

  • (x) => not(x)

If we know that “not” has the type Bool -> Bool, then x must have the type Bool.

So, the infer function must return an additional result along with the type of the expression. It must also return a mapping of type variables to types. This mapping is called a substitution. Be careful not to confuse substitutions with environments. A substitution maps type variables to types whereas an environment maps variables (which are expressions) to types.

So now back to the case of functions. Once we’ve called infer for the body, we get back a type (the type of the body) and a substitution, which may or may not contain the actual type of the parameter’s type variable that we added to the environment before the call. The type of the function must then be

  • <type_of_param> -> <type_of_body>

We need to “apply” the returned substitution to the type variable we generated. Applying a substitution to a type means replacing all (unbound; more on that in later parts) occurrences of the variables present in the substitution with their values from the substitution. So, after applying the substitution to the type variable (in this case, simply replacing it with its value from the substitution, if present), we can return the function type from param type to body type, along with the substitution returned for the body. Note that ideally we should remove the parameter’s type variable from the substitution that we return, because the scope of that type variable is limited to the function’s body. We don’t do that here because the type variables we generate are guaranteed to be unique, so there’s no chance of collisions.

That was a lot of talk so let’s look at the code.
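Here is a minimal sketch of these pieces, with the tree shapes repeated so it stands alone. One deliberate deviation from the prose: the counter lives in a small counter object that every copy of the context shares, instead of a plain next number on the context. If addToContext copied a plain number, type variables generated inside a function body would not advance the outer context’s counter, and names could repeat. Call and If nodes are left unimplemented here:

```typescript
type Expression =
  | { nodeType: "Int"; value: number }
  | { nodeType: "Var"; name: string }
  | { nodeType: "Function"; param: string; body: Expression }
  | { nodeType: "Call"; func: Expression; arg: Expression }
  | { nodeType: "If"; cond: Expression; trueBranch: Expression; falseBranch: Expression };
type Type =
  | { nodeType: "Named"; name: string }
  | { nodeType: "Var"; name: string }
  | { nodeType: "Function"; from: Type; to: Type };
type Env = { [name: string]: Type };                // variables -> types
type Substitution = { [typeVarName: string]: Type }; // type variables -> types
// The environment plus a counter for fresh type variable names (shared by all copies).
type Context = { env: Env; counter: { next: number } };

const hasOwn = (obj: object, key: string): boolean =>
  Object.prototype.hasOwnProperty.call(obj, key);

// Replaces every type variable found in the substitution with its mapped type.
function applySubstToType(subst: Substitution, type: Type): Type {
  if (type.nodeType === "Var") return hasOwn(subst, type.name) ? subst[type.name] : type;
  if (type.nodeType === "Function") {
    return {
      nodeType: "Function",
      from: applySubstToType(subst, type.from),
      to: applySubstToType(subst, type.to),
    };
  }
  return type; // Named types contain no type variables
}

// Returns a new context; neither ctx nor ctx.env is modified.
function addToContext(ctx: Context, name: string, type: Type): Context {
  return { env: { ...ctx.env, [name]: type }, counter: ctx.counter };
}

// Mutates the shared counter, so names are unique across all scopes: T1, T2, ...
function newTVar(ctx: Context): Type {
  ctx.counter.next += 1;
  return { nodeType: "Var", name: `T${ctx.counter.next}` };
}

function infer(ctx: Context, e: Expression): [Type, Substitution] {
  switch (e.nodeType) {
    case "Int":
      return [{ nodeType: "Named", name: "Int" }, {}];
    case "Var":
      if (!hasOwn(ctx.env, e.name)) throw new Error(`Unbound variable: ${e.name}`);
      return [ctx.env[e.name], {}];
    case "Function": {
      const paramType = newTVar(ctx);                           // 1. fresh type variable
      const bodyCtx = addToContext(ctx, e.param, paramType);    // 2. extend a copy
      const [bodyType, subst] = infer(bodyCtx, e.body);         // 3. infer the body
      const inferredParam = applySubstToType(subst, paramType); // 4. apply the substitution
      const funcType: Type = { nodeType: "Function", from: inferredParam, to: bodyType }; // 5.
      return [funcType, subst];                                 // 6.
    }
    default:
      throw new Error(`Unimplemented: ${e.nodeType}`); // Call and If come later
  }
}
```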

Notice that we replaced the env parameter with a context parameter. The context contains the environment along with some other information we might need: we keep a count of the type variables generated so far, so that we can generate the next one.

We added a new function applySubstToType, which recursively traverses a type node. For a type variable, it looks the variable up in the given substitution and returns the mapped type if there is one; otherwise, it returns the type variable itself. For other types, it recursively applies the substitution to the child nodes.

Function addToContext takes a context, a variable name and a type, and adds the mapping of name to type to the context’s environment. Notice that it doesn’t mutate the original context or its environment. It makes a copy and returns a new context. This is because when we bind a new variable, it should only be visible in the scope it was declared in. We don’t want to add it to the global scope.

The function newTVar takes a context, increments its next property and returns a new type variable named T<context.next> (T1, T2, …). Each call generates a new type variable. This does mutate the given context, because the generated names should be unique across scopes. This isn’t strictly required, but it makes the implementation slightly easier.

Finally, we change the signature of the infer function so that it takes a context instead of an environment and returns a type and a substitution. The case for functions is added, which does exactly what I explained earlier:

  1. Generate a new type variable for the parameter.
  2. Add the generated type to a copy of the context.
  3. Call infer with the function’s body and the new context.
  4. Apply the returned substitution to the type variable we generated in step 1.
  5. Create a new function type with the parameter’s type and body type.
  6. Return the function’s type and substitution.

Note that the cases for variables and Ints just return an empty substitution.

Now, to the final part of the puzzle.

Function calls

As you can see from the type of “Call” nodes, a call has two subexpressions: the function to be called and the argument.

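First, we need to call infer with the function node. Then we check that the returned type can be a function: if it’s a named type like Int, we throw a type error saying that a function was expected. (A type variable is fine, since it may still turn out to stand for a function.)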

Now we need to apply the resulting substitution to the context’s environment, so that the types of any type variables inferred by the last call are available when inferring the argument’s type. Then infer is called with this new context and the argument node.

Now, we need to match the type of the argument with the type expected by the function. This isn’t as simple as comparing the types, because the types may contain type variables that stand for other types. This is where unification comes in.

Basically, unification is a way to check if 2 types “fit” and if they do, return a substitution that makes them equal. For example, if we try to unify (Int -> A) and (Int -> Bool), we’ll get a substitution {“A”: Bool}, because “A” is a type variable, which can stand in for any type and the second type requires it to be Bool. That’s the intuitive idea behind it.

Here’s the code

First, let’s look at the unify function, which takes two types and returns a substitution that makes them match. There are a few cases in this function:

  • When both the types are named types, and they’re the same, we return an empty substitution because they already match.
  • If both are type variables with the same name, we return an empty substitution, because they already match.
  • If one of them is a type variable, and the other type contains the same type variable, we throw an error. This is because unifying them would lead to an infinitely recursive type. For example, if A and A -> B unify, this would mean that A is A -> B, which would mean A -> B is (A -> B) -> B. You can see that this will go on forever.
  • If one of them is a type variable and the other type doesn’t contain it, we can “bind” the type variable to the other type. Because a type variable can stand in for any type, and a call to unify means both types must be the same, the type variable must stand for the other type.
  • When both types are functions, we recursively unify their param types and return types and compose the resulting substitutions (using composeSubst).
  • Any other case is a type error.

Also notice that here too, at each step, when we get a new substitution, we apply it to the older types as a way to update them with the new constraints generated by calling unify.
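A sketch of unify and composeSubst following the cases above, including the occurs check. The helpers occursIn and bind are hypothetical names introduced for readability, and errors are thrown as plain Error objects:

```typescript
type Type =
  | { nodeType: "Named"; name: string }
  | { nodeType: "Var"; name: string }
  | { nodeType: "Function"; from: Type; to: Type };
type Substitution = { [typeVarName: string]: Type };

const hasOwn = (obj: object, key: string): boolean =>
  Object.prototype.hasOwnProperty.call(obj, key);

function applySubstToType(s: Substitution, t: Type): Type {
  if (t.nodeType === "Var") return hasOwn(s, t.name) ? s[t.name] : t;
  if (t.nodeType === "Function")
    return { nodeType: "Function", from: applySubstToType(s, t.from), to: applySubstToType(s, t.to) };
  return t;
}

// composeSubst(s1, s2): apply s1 to the types in s2, then combine with s1.
// Applying the result is the same as applying s2 first and then s1.
function composeSubst(s1: Substitution, s2: Substitution): Substitution {
  const result: Substitution = { ...s1 };
  Object.keys(s2).forEach((k) => { result[k] = applySubstToType(s1, s2[k]); });
  return result;
}

// Does the type variable `name` occur anywhere inside t?
function occursIn(name: string, t: Type): boolean {
  if (t.nodeType === "Var") return t.name === name;
  if (t.nodeType === "Function") return occursIn(name, t.from) || occursIn(name, t.to);
  return false;
}

function bind(name: string, t: Type): Substitution {
  if (t.nodeType === "Var" && t.name === name) return {}; // same variable: already equal
  if (occursIn(name, t)) throw new Error(`Infinite type: ${name} occurs in ${JSON.stringify(t)}`);
  return { [name]: t };
}

function unify(a: Type, b: Type): Substitution {
  if (a.nodeType === "Named" && b.nodeType === "Named" && a.name === b.name) return {};
  if (a.nodeType === "Var") return bind(a.name, b);
  if (b.nodeType === "Var") return bind(b.name, a);
  if (a.nodeType === "Function" && b.nodeType === "Function") {
    const s1 = unify(a.from, b.from);
    // Apply s1 before unifying the return types so they see the new constraints.
    const s2 = unify(applySubstToType(s1, a.to), applySubstToType(s1, b.to));
    return composeSubst(s2, s1);
  }
  throw new Error(`Type mismatch: ${JSON.stringify(a)} vs ${JSON.stringify(b)}`);
}
```

The order in the function case matters: unifying A -> B with B -> Int first binds A to B, then binds B to Int, and composeSubst(s2, s1) rewrites A’s entry so that A ends up as Int rather than B. A plain object merge would leave A mapped to B.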

Now, back to the infer function for call expressions. As you can see here, we generate a new type variable for the return type of the function, infer the type of the argument, and then unify the inferred type of the function with the type <argType> -> <newVar>. Finally, we unify the function’s param type with the actual argument’s type, because you shouldn’t be able to pass, say, an Int to a function that takes a Bool argument, right? At each step, we have to apply newly generated substitutions to the earlier types. Also, we use composeSubst (* This is a correction. Earlier, I was using simple Object.assign. See EDIT note at the end) to “compose” multiple substitutions. The composeSubst function applies the first substitution to the types of the second one and then combines the result with the first substitution. At the end, we apply the combined substitution to the function type’s return type and return it, along with the combined substitution. Intuitively, a function of type (A -> B) applied to an argument of type A should have type B.
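A self-contained sketch of inference with the Call case added, following the application rule of Algorithm W (the classic formulation of Hindley-Milner). Helpers are repeated so it runs on its own; the counter is kept in a shared object so that copies made by addToContext never reuse names. In this version a single unification of the function’s type with argType -> resultType does three jobs at once: it rejects non-functions such as Int, it checks the parameter type against the argument, and it binds resultType. It also accepts a function whose type is still a type variable, as in (f) => f(1):

```typescript
type Expression =
  | { nodeType: "Int"; value: number }
  | { nodeType: "Var"; name: string }
  | { nodeType: "Function"; param: string; body: Expression }
  | { nodeType: "Call"; func: Expression; arg: Expression }
  | { nodeType: "If"; cond: Expression; trueBranch: Expression; falseBranch: Expression };
type Type =
  | { nodeType: "Named"; name: string }
  | { nodeType: "Var"; name: string }
  | { nodeType: "Function"; from: Type; to: Type };
type Env = { [name: string]: Type };
type Substitution = { [typeVarName: string]: Type };
type Context = { env: Env; counter: { next: number } };

const hasOwn = (obj: object, key: string): boolean => Object.prototype.hasOwnProperty.call(obj, key);

function applySubstToType(s: Substitution, t: Type): Type {
  if (t.nodeType === "Var") return hasOwn(s, t.name) ? s[t.name] : t;
  if (t.nodeType === "Function")
    return { nodeType: "Function", from: applySubstToType(s, t.from), to: applySubstToType(s, t.to) };
  return t;
}

function applySubstToContext(s: Substitution, ctx: Context): Context {
  const env: Env = {};
  Object.keys(ctx.env).forEach((k) => { env[k] = applySubstToType(s, ctx.env[k]); });
  return { env, counter: ctx.counter }; // the counter stays shared
}

// composeSubst(s1, s2): apply s1 to the types in s2, then combine with s1.
function composeSubst(s1: Substitution, s2: Substitution): Substitution {
  const result: Substitution = { ...s1 };
  Object.keys(s2).forEach((k) => { result[k] = applySubstToType(s1, s2[k]); });
  return result;
}

function occursIn(name: string, t: Type): boolean {
  if (t.nodeType === "Var") return t.name === name;
  if (t.nodeType === "Function") return occursIn(name, t.from) || occursIn(name, t.to);
  return false;
}

function bind(name: string, t: Type): Substitution {
  if (t.nodeType === "Var" && t.name === name) return {};
  if (occursIn(name, t)) throw new Error(`Infinite type: ${name} occurs in ${JSON.stringify(t)}`);
  return { [name]: t };
}

function unify(a: Type, b: Type): Substitution {
  if (a.nodeType === "Named" && b.nodeType === "Named" && a.name === b.name) return {};
  if (a.nodeType === "Var") return bind(a.name, b);
  if (b.nodeType === "Var") return bind(b.name, a);
  if (a.nodeType === "Function" && b.nodeType === "Function") {
    const s1 = unify(a.from, b.from);
    const s2 = unify(applySubstToType(s1, a.to), applySubstToType(s1, b.to));
    return composeSubst(s2, s1);
  }
  throw new Error(`Type mismatch: ${JSON.stringify(a)} vs ${JSON.stringify(b)}`);
}

function newTVar(ctx: Context): Type {
  ctx.counter.next += 1;
  return { nodeType: "Var", name: `T${ctx.counter.next}` };
}

function addToContext(ctx: Context, name: string, type: Type): Context {
  return { env: { ...ctx.env, [name]: type }, counter: ctx.counter };
}

function infer(ctx: Context, e: Expression): [Type, Substitution] {
  switch (e.nodeType) {
    case "Int":
      return [{ nodeType: "Named", name: "Int" }, {}];
    case "Var":
      if (!hasOwn(ctx.env, e.name)) throw new Error(`Unbound variable: ${e.name}`);
      return [ctx.env[e.name], {}];
    case "Function": {
      const paramType = newTVar(ctx);
      const [bodyType, s] = infer(addToContext(ctx, e.param, paramType), e.body);
      return [{ nodeType: "Function", from: applySubstToType(s, paramType), to: bodyType }, s];
    }
    case "Call": {
      const resultType = newTVar(ctx); // stands for the type of the whole call
      const [funcType, s1] = infer(ctx, e.func);
      // Apply s1 so the argument sees what we learned while inferring the function.
      const [argType, s2] = infer(applySubstToContext(s1, ctx), e.arg);
      // The function must fit argType -> resultType; this rejects non-functions
      // and checks the parameter type against the argument type in one step.
      const s3 = unify(applySubstToType(s2, funcType), { nodeType: "Function", from: argType, to: resultType });
      return [applySubstToType(s3, resultType), composeSubst(s3, composeSubst(s2, s1))];
    }
    default:
      throw new Error(`Unimplemented: ${e.nodeType}`);
  }
}
```

For example, starting from a fresh context whose environment contains not : Bool -> Bool, inferring (x) => not(x) gives Bool -> Bool, and (f) => f(1) gives (Int -> T2) -> T2.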

That’s pretty much all you need to know for now. Just for completeness, here’s the case for If expressions.

If expressions

It should be straightforward. We check that the condition’s type unifies with Bool and that the types of the two branches unify with each other. Then we simply return the type of the first branch (it could just as well be the second, since both unify). Before looking at the following snippet, I suggest you try writing it yourself. You don’t need any functions beyond those we’ve already defined. Just make sure you don’t throw away any substitutions.

Here’s the code.
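One possible way to write it, as a self-contained sketch that handles only Int, Var and If nodes (Function and Call are omitted to keep it short). The unification helpers mirror those used for calls; applySubstToEnv is an assumed helper that applies a substitution to every type in the environment:

```typescript
type Expression =
  | { nodeType: "Int"; value: number }
  | { nodeType: "Var"; name: string }
  | { nodeType: "Function"; param: string; body: Expression }
  | { nodeType: "Call"; func: Expression; arg: Expression }
  | { nodeType: "If"; cond: Expression; trueBranch: Expression; falseBranch: Expression };
type Type =
  | { nodeType: "Named"; name: string }
  | { nodeType: "Var"; name: string }
  | { nodeType: "Function"; from: Type; to: Type };
type Env = { [name: string]: Type };
type Substitution = { [typeVarName: string]: Type };

const BOOL: Type = { nodeType: "Named", name: "Bool" };
const hasOwn = (obj: object, key: string): boolean => Object.prototype.hasOwnProperty.call(obj, key);

function applySubstToType(s: Substitution, t: Type): Type {
  if (t.nodeType === "Var") return hasOwn(s, t.name) ? s[t.name] : t;
  if (t.nodeType === "Function")
    return { nodeType: "Function", from: applySubstToType(s, t.from), to: applySubstToType(s, t.to) };
  return t;
}

function applySubstToEnv(s: Substitution, env: Env): Env {
  const result: Env = {};
  Object.keys(env).forEach((k) => { result[k] = applySubstToType(s, env[k]); });
  return result;
}

function composeSubst(s1: Substitution, s2: Substitution): Substitution {
  const result: Substitution = { ...s1 };
  Object.keys(s2).forEach((k) => { result[k] = applySubstToType(s1, s2[k]); });
  return result;
}

function occursIn(name: string, t: Type): boolean {
  if (t.nodeType === "Var") return t.name === name;
  if (t.nodeType === "Function") return occursIn(name, t.from) || occursIn(name, t.to);
  return false;
}

function bind(name: string, t: Type): Substitution {
  if (t.nodeType === "Var" && t.name === name) return {};
  if (occursIn(name, t)) throw new Error(`Infinite type: ${name} occurs in ${JSON.stringify(t)}`);
  return { [name]: t };
}

function unify(a: Type, b: Type): Substitution {
  if (a.nodeType === "Named" && b.nodeType === "Named" && a.name === b.name) return {};
  if (a.nodeType === "Var") return bind(a.name, b);
  if (b.nodeType === "Var") return bind(b.name, a);
  if (a.nodeType === "Function" && b.nodeType === "Function") {
    const s1 = unify(a.from, b.from);
    const s2 = unify(applySubstToType(s1, a.to), applySubstToType(s1, b.to));
    return composeSubst(s2, s1);
  }
  throw new Error(`Type mismatch: ${JSON.stringify(a)} vs ${JSON.stringify(b)}`);
}

function infer(env: Env, e: Expression): [Type, Substitution] {
  switch (e.nodeType) {
    case "Int":
      return [{ nodeType: "Named", name: "Int" }, {}];
    case "Var":
      if (!hasOwn(env, e.name)) throw new Error(`Unbound variable: ${e.name}`);
      return [env[e.name], {}];
    case "If": {
      const [condType, s1] = infer(env, e.cond);
      const s2 = composeSubst(unify(condType, BOOL), s1); // the condition must be a Bool
      // Infer each branch in an environment updated with everything learned so far.
      const [thenType, s3] = infer(applySubstToEnv(s2, env), e.trueBranch);
      const s4 = composeSubst(s3, s2);
      const [elseType, s5] = infer(applySubstToEnv(s4, env), e.falseBranch);
      // Both branches must have the same type.
      const s6 = unify(applySubstToType(s5, thenType), elseType);
      const s = composeSubst(s6, composeSubst(s5, s4));
      return [applySubstToType(s, thenType), s]; // the type of the first branch
    }
    default:
      throw new Error(`Unimplemented: ${e.nodeType}`);
  }
}
```

Updating the environment before each branch is what keeps substitutions from being thrown away: with c : T1 and a : T1, the expression if (c) a else 1 is rejected, because once the condition forces T1 to be Bool, a is known to be a Bool and cannot match the Int in the else branch.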

Testing out our typechecker

You can call the infer function on expression trees to check whether their types fit. Writing out bigger expressions by hand gets tedious, so I wrote a few helper functions to create expression trees.

Now, we can test things by calling infer with an initial context. You can add some built-ins like “true”, “false”, “==”, “+”, etc.
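A sketch of such helpers and an initial context. The helper names (num, v, fn, call, ifElse, named, arrow), the context shape with a shared counter object, and the types chosen for the built-ins are all assumptions for illustration. The call helper builds curried calls, so call(f, a, b) produces the tree for f(a)(b):

```typescript
type Expression =
  | { nodeType: "Int"; value: number }
  | { nodeType: "Var"; name: string }
  | { nodeType: "Function"; param: string; body: Expression }
  | { nodeType: "Call"; func: Expression; arg: Expression }
  | { nodeType: "If"; cond: Expression; trueBranch: Expression; falseBranch: Expression };
type Type =
  | { nodeType: "Named"; name: string }
  | { nodeType: "Var"; name: string }
  | { nodeType: "Function"; from: Type; to: Type };
type Env = { [name: string]: Type };
type Context = { env: Env; counter: { next: number } };

// Expression helpers (hypothetical names).
const num = (value: number): Expression => ({ nodeType: "Int", value });
const v = (name: string): Expression => ({ nodeType: "Var", name });
const fn = (param: string, body: Expression): Expression => ({ nodeType: "Function", param, body });
// Curried call: call(f, a, b) builds f(a)(b), since calls associate to the left.
const call = (func: Expression, ...args: Expression[]): Expression =>
  args.reduce<Expression>((f, arg) => ({ nodeType: "Call", func: f, arg }), func);
const ifElse = (cond: Expression, trueBranch: Expression, falseBranch: Expression): Expression =>
  ({ nodeType: "If", cond, trueBranch, falseBranch });

// Type helpers (hypothetical names).
const named = (name: string): Type => ({ nodeType: "Named", name });
const arrow = (from: Type, to: Type): Type => ({ nodeType: "Function", from, to });

// An initial context with a few built-ins; these particular types are assumptions.
const initialContext: Context = {
  env: {
    true: named("Bool"),
    false: named("Bool"),
    "==": arrow(named("Int"), arrow(named("Int"), named("Bool"))),
    "+": arrow(named("Int"), arrow(named("Int"), named("Int"))),
  },
  counter: { next: 0 },
};
```

Because generating type variables advances the counter, it is simplest to build a fresh context for each test.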

Here are some very simple tests

You can test it with more complicated expressions to check that everything works.

Conclusion

I hope this gave you some idea of how type inference works. This might have been a lot to take in all at once, but if you actually write the code out, you’ll start understanding it better. Play around with it a bit. Maybe add a few things. You could start by showing the actual expression where a type error occurred.

In the next part, we’ll add polymorphism/generic types to our system. That would allow us to express functions that work on any type.

In the meantime, I would appreciate corrections and suggestions. I’m sure there are plenty of subtle bugs in the code ;)

EDIT: There was a bug in this part’s code. Instead of merging substitutions using Object.assign, the first one has to be applied to the types of the second one, and the result then combined with the first. Thanks to @mbulfone for pointing this out.

EDIT2: There was another bug in infer for the case of “Call” expressions. I had forgotten to apply the substitution produced by unifying with the expected function type to the funcType variable. I’ve fixed the snippet.
