TensorFlow from Scratch

TensorFlow is cool. Let’s build it from scratch.

Just kidding: we can’t build all of it, mainly because I don’t know how. But we can implement one of its core components, which lets us define a computational graph to achieve some objective.

If you just want to see the code, here it is.

Disclaimer: I’m sure this is not how TensorFlow actually does any of this under the hood. This is just an educational exercise, so please don’t yell at me for it, because it hurts my feelings. (I’m looking at you, Hacker News.)

A Quick Primer on TensorFlow

Essentially, TensorFlow lets you define a model in terms of how data flows through a computation graph, without worrying about how to actually solve for the model’s parameters.

Let’s think about a problem and what the ideal framework for solving it would look like.

Let’s say that, given a student’s test score, we want to predict their IQ. We are given the training set plotted below.

Then a new student scores 17 on the exam, and we want to predict their IQ.

I’m pretty tempted to just draw a line through this and call it a day, so the obvious choice for modelling this relationship is a linear model (recall: y = wx + b). Let’s hack together some code that represents it:

def predict(w, x, b):
    return w * x + b

But what values should we choose for w and b? (Remember, x is our data, so we can’t choose its value.) Answering this requires looking at the training data. If you are familiar with statistics, you’ll know how to do this with ordinary least squares. (If not, don’t worry about it.)
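For a single input variable, ordinary least squares has a closed-form answer: w is the covariance of x and y divided by the variance of x, and b = mean(y) - w · mean(x). The sketch below shows that formula. The helper name `fit_line` and the four data points are made up for illustration and are not the training set plotted above.

```python
def fit_line(xs, ys):
    """Ordinary least squares fit of y = w * x + b for a single input variable."""
    n = len(xs)
    if n < 2 or n != len(ys):
        raise ValueError("need at least two (x, y) pairs of matching length")
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    # Sum of co-deviations of x and y, and sum of squared deviations of x.
    cov_xy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    if var_x == 0:
        raise ValueError("need at least two distinct x values")
    w = cov_xy / var_x
    b = mean_y - w * mean_x
    return w, b


def predict(w, x, b):
    return w * x + b


# Hypothetical points lying exactly on y = 2x + 1.
w, b = fit_line([1, 2, 3, 4], [3, 5, 7, 9])
print(w, b)               # 2.0 1.0
print(predict(w, 17, b))  # 35.0
```

This closed form only works because the model is a straight line. The point of the framework below is to find w and b without deriving a formula like this by hand.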

But is it possible to build a framework that automatically finds the best values for w and b, given only the code above and the training data?

Obviously I wouldn’t be writing this if there weren’t such a framework, and we will build one from scratch to see how it could possibly work. However, we will need to give it a bit more information. So let’s first look at a potential API that could be used to fit the data above.

We have some x:

x = PlaceholderNode(name="x")

That is multiplied with some w:

w = VariableNode(0.5, name="w")
multiply = MultiplyNode(w, x, name='multiply')

and then added to some b:

b = VariableNode(0.1, name="b")
prediction = AddNode(multiply, b, name='prediction')

This is exactly what our predict function above did. Next, we need to tell the framework a bit more about what we actually want to happen.

The code below calculates how far our prediction is from the correct value, and then squares that difference to get the error:

x = PlaceholderNode(name="x")
y = PlaceholderNode(name="y")
w = VariableNode(0.5, name="w")
b = VariableNode(0.1, name="b")

prediction = AddNode(MultiplyNode(w, x, name='multiply'), b, name='prediction')
# calculate the error
difference = SubtractNode(prediction, y, name='difference')
error = SquaredNode(difference, name='error')

From the code above, our framework can construct the following computation graph.

This computation graph defines a few things, which I represent with shapes. The diamonds represent operations (add, subtract, multiply, squared), the circles represent parameters (w and b), and the squares represent data (x and y).

So given this information, what do we actually want our framework to do? We want it to tweak the parameters so that the error is as small as possible across all of the training data.
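Written out, "as small as possible across all of the training data" means minimizing the total squared error over the N training pairs (x_i, y_i). A common way to do the tweaking is gradient descent: repeatedly nudge each parameter against its derivative. The learning rate η is an assumption of this sketch and is not something the post specifies.

```latex
E(w, b) = \sum_{i=1}^{N} \left( w x_i + b - y_i \right)^2,
\qquad
(w^{\ast}, b^{\ast}) = \arg\min_{w,\, b} E(w, b)

w \leftarrow w - \eta \, \frac{\partial E}{\partial w},
\qquad
b \leftarrow b - \eta \, \frac{\partial E}{\partial b}
```

The derivatives ∂E/∂w and ∂E/∂b are exactly what back propagation through the graph has to supply.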


I don’t really want to be the 100th person on the internet to write yet another back propagation tutorial, but we do need to talk about what back propagation looks like on a computational graph.

If we revisit our computation graph:

Remember, in back propagation we want to find ∂Error/∂w and ∂Error/∂b. To do this with our computation graph, we simply traverse the graph backwards, from Error back to w and b. Here is what that traversal looks like for w, starting from Error.
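Written out for a single training example, the traversal multiplies one local derivative per edge. The node names come from the code above: multiply = wx, prediction = multiply + b, difference = prediction - y, and Error = difference².

```latex
\frac{\partial \mathrm{Error}}{\partial w}
  = \frac{\partial \mathrm{Error}}{\partial \mathrm{difference}}
    \cdot \frac{\partial \mathrm{difference}}{\partial \mathrm{prediction}}
    \cdot \frac{\partial \mathrm{prediction}}{\partial \mathrm{multiply}}
    \cdot \frac{\partial \mathrm{multiply}}{\partial w}
  = 2\,(wx + b - y) \cdot 1 \cdot 1 \cdot x

\frac{\partial \mathrm{Error}}{\partial b}
  = \frac{\partial \mathrm{Error}}{\partial \mathrm{difference}}
    \cdot \frac{\partial \mathrm{difference}}{\partial \mathrm{prediction}}
    \cdot \frac{\partial \mathrm{prediction}}{\partial b}
  = 2\,(wx + b - y) \cdot 1 \cdot 1
```

For example, with w = 0.5, b = 1, x = 10 and y = 4, the difference is 2. That gives ∂Error/∂w = 2 · 2 · 10 = 40 and ∂Error/∂b = 2 · 2 = 4.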

To back propagate from Error to w, we recursively walk back through the graph. At every node, we take the derivative of that node’s operation with respect to the input on our path. Multiplying these local derivatives together is exactly an expansion of the chain rule.
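To make the recursive walk concrete, here is a minimal, self-contained sketch. It deliberately does not use the node classes from this post. Instead, the graph is a hypothetical nested-tuple representation such as ("mul", "w", "x"), and the helpers `evaluate`, `local_derivatives` and `derivative` are my own illustrative names.

`derivative(node, target)` works like this:

- At a leaf, it returns 1 or 0, depending on whether the leaf is the target.
- Everywhere else, it sums, over each input, the node's local derivative times the derivative of that input.

That recursion is the chain rule.

```python
def evaluate(node, values):
    """Forward pass: recursively compute a node's value from the leaf values."""
    if isinstance(node, str):  # leaf: a placeholder or variable name
        return values[node]
    op, *children = node
    args = [evaluate(child, values) for child in children]
    if op == "add":
        return args[0] + args[1]
    if op == "sub":
        return args[0] - args[1]
    if op == "mul":
        return args[0] * args[1]
    if op == "square":
        return args[0] ** 2
    raise ValueError(f"unknown op {op!r}")


def local_derivatives(node, values):
    """Partial derivative of this node's operation w.r.t. each of its direct inputs."""
    op, *children = node
    args = [evaluate(child, values) for child in children]
    if op == "add":
        return [1.0, 1.0]
    if op == "sub":
        return [1.0, -1.0]
    if op == "mul":
        return [args[1], args[0]]  # d(a*b)/da = b, d(a*b)/db = a
    if op == "square":
        return [2 * args[0]]
    raise ValueError(f"unknown op {op!r}")


def derivative(node, target, values):
    """d(node)/d(target), found by walking back through the graph (chain rule)."""
    if isinstance(node, str):
        return 1.0 if node == target else 0.0
    children = node[1:]
    return sum(
        local * derivative(child, target, values)
        for child, local in zip(children, local_derivatives(node, values))
    )


# error = ((w * x + b) - y) ** 2, the same graph as the linear model above
error_graph = ("square", ("sub", ("add", ("mul", "w", "x"), "b"), "y"))
values = {"w": 0.5, "b": 1.0, "x": 10.0, "y": 4.0}

print(evaluate(error_graph, values))         # 4.0
print(derivative(error_graph, "w", values))  # 40.0
print(derivative(error_graph, "b", values))  # 4.0
```

This sketch re-runs `evaluate` inside every call, so it recomputes values many times. A real implementation would cache the forward-pass values and compute the gradients for all parameters in a single backward sweep.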

Building TensorFlow

OK, now let’s start implementing this framework. We can begin by creating classes for each of the nodes used in the example code above:

Placeholders and Variables

Placeholders and Variables are fundamental concepts in TensorFlow, but they can often be confusing to new users.

Let’s define the difference between a PlaceholderNode and a VariableNode. A PlaceholderNode is an input to the model, while a VariableNode represents a parameter of the model. Our framework can only update the values of VariableNodes. Recall that our linear model has 4 values:

x: training input

y: training output

w: slope

b: intercept

In this case, w and b are VariableNodes because they are the parameters of the model, while x and y are PlaceholderNodes because they are inputs to the model.

Now we have to fill out the implementation for the nodes. Recall that back propagation has two steps:

- a forward pass, where we compute the model’s prediction;
- a backward pass, where we calculate the error of the prediction and back-propagate it to the weights.

Since we have a computation graph, we can compute the value of any node by first evaluating all of its dependencies.

Imagine a graph of x + y - z:

Computation Graph for x + y - z

In this case, we evaluate the - node by working our way up the tree. First we evaluate the + node, then we take its value and subtract z from it. For example, with x = 3, y = 4 and z = 2, the + node evaluates to 7 and the - node to 7 - 2 = 5.

Let’s apply this strategy by defining a compute method for each of our nodes:
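The original listing for this step is not reproduced here, so what follows is a minimal sketch of how these compute methods could look. It is not necessarily how the linked code does it.

It assumes a hypothetical base class `Node` that stores its input nodes and defines a generic `compute`. That method evaluates the inputs first and then calls a per-class `operation` hook, which is my own addition.

The two special kinds of node work differently:

- `PlaceholderNode` overrides `compute` to read its value from `feed_dict`.
- `VariableNode` overrides `compute` to return its stored value.

```python
class Node:
    """Base class: a node in the computation graph with zero or more input nodes."""

    def __init__(self, *inputs, name=None):
        self.inputs = inputs
        self.name = name

    def compute(self, feed_dict=None):
        # Forward pass: evaluate every input first, then apply this node's operation.
        feed_dict = feed_dict or {}
        values = [node.compute(feed_dict) for node in self.inputs]
        return self.operation(*values)

    def operation(self, *values):
        raise NotImplementedError


class PlaceholderNode(Node):
    """Model input: its value is supplied at compute time through feed_dict."""

    def compute(self, feed_dict=None):
        if not feed_dict or self not in feed_dict:
            raise KeyError(f"no value fed for placeholder {self.name!r}")
        return feed_dict[self]


class VariableNode(Node):
    """Model parameter: holds a value that training is allowed to update."""

    def __init__(self, value, name=None):
        super().__init__(name=name)
        self.value = value

    def compute(self, feed_dict=None):
        return self.value


class AddNode(Node):
    def operation(self, a, b):
        return a + b


class SubtractNode(Node):
    def operation(self, a, b):
        return a - b


class MultiplyNode(Node):
    def operation(self, a, b):
        return a * b


class SquaredNode(Node):
    def operation(self, a):
        return a ** 2


# The x + y - z example: the + node is evaluated before the - node.
x = PlaceholderNode(name="x")
y = PlaceholderNode(name="y")
z = PlaceholderNode(name="z")
expr = SubtractNode(AddNode(x, y, name="add"), z, name="subtract")
print(expr.compute(feed_dict={x: 3, y: 4, z: 2}))  # 5

# The linear model and its squared error.
w = VariableNode(0.5, name="w")
b = VariableNode(1, name="b")
prediction = AddNode(MultiplyNode(w, x, name="multiply"), b, name="prediction")
difference = SubtractNode(prediction, y, name="difference")
error = SquaredNode(difference, name="error")
print(error.compute(feed_dict={x: 10, y: 4}))  # 4.0
```

Nodes are used directly as `feed_dict` keys. This works because Python objects hash by identity unless a class overrides `__eq__`/`__hash__`.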

With that in place, we can call compute on any node of the API defined above to evaluate it. We pass in a feed_dict that supplies the values of the placeholder nodes.

x = PlaceholderNode(name="x")
y = PlaceholderNode(name="y")
w = VariableNode(0.5, name="w")
b = VariableNode(1, name="b")

prediction = AddNode(MultiplyNode(w, x, name='multiply'), b, name='prediction')
# calculate the error
difference = SubtractNode(prediction, y, name='difference')
error = SquaredNode(difference, name='error')
prediction.compute(feed_dict={x: 10, y: 4})  # 0.5 * 10 + 1 = 6
difference.compute(feed_dict={x: 10, y: 4})  # 6 - 4 = 2
error.compute(feed_dict={x: 10, y: 4})       # 2² = 4

And with this, we have implemented the forward pass. In the next blog post I will discuss implementing the backward pass. In the meantime, all the code for the forward and backward passes, along with a demo for a linear model and a quadratic model, can be found here.