Using the RNN API in TensorFlow (2/7)

Erik Hallström
2 min read · Nov 14, 2016

Dear reader,

This article has been republished at Educaora and has also been open sourced. Unfortunately, TensorFlow 2.0 changed the API, so the code here is broken for later versions. Any help in bringing the tutorials up to date is greatly appreciated. I also recommend looking into PyTorch.

This post is the follow-up to the article “How to build a Recurrent Neural Network in TensorFlow”, where we built an RNN from scratch by constructing the computational graph manually. Now we will use the native TensorFlow API to simplify our script.

Simple graph creation

Remember where we did the unpacking and the forward pass in the vanilla RNN?
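For reference, this is roughly what that looked like (a sketch based on the previous post, with variable names assumed from part 1; it targets the pre-1.0 TensorFlow API):

```python
# Unpack the input and label matrices into lists of column tensors
x_inputs = tf.unpack(batchX_placeholder, axis=1)
labels_series = tf.unpack(batchY_placeholder, axis=1)

# Forward pass: unroll the RNN manually, one set of graph nodes per time step
current_state = init_state
states_series = []
for current_input in x_inputs:
    current_input = tf.reshape(current_input, [batch_size, 1])
    # Concatenate input and state along the column dimension
    input_and_state_concatenated = tf.concat(1, [current_input, current_state])
    next_state = tf.tanh(tf.matmul(input_and_state_concatenated, W) + b)
    states_series.append(next_state)
    current_state = next_state
```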

Replace the piece of code above with this:
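Something along these lines (again a sketch against the pre-1.0 API; the BasicRNNCell now owns the weight matrix and bias internally):

```python
# Split columns: a list of truncated_backprop_length tensors of shape [batch_size, 1]
x_inputs = tf.split(1, truncated_backprop_length, batchX_placeholder)
labels_series = tf.unpack(batchY_placeholder, axis=1)

# Forward pass: tf.nn.rnn unrolls the cell and builds the graph for us
cell = tf.nn.rnn_cell.BasicRNNCell(state_size)
states_series, current_state = tf.nn.rnn(cell, x_inputs, init_state)
```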

You may also remove the weight and bias matrices W and b declared earlier; the inner workings of the RNN are now hidden “under the hood”. Notice the usage of split instead of unpack when assigning the x_inputs variable. tf.nn.rnn accepts a list of inputs of shape [batch_size, input_size], and input_size is simply one in our case (the input is just a series of scalars). split does not remove the singular dimension, but unpack does; you can read more about it here. It doesn’t really matter anyway, since we still had to reshape the inputs in our previous example before the matrix multiplication. tf.nn.rnn unrolls the RNN and creates the graph automatically, so we can remove the for-loop. The function returns the series of states as well as the last state, in the same shapes that we previously produced manually; you can verify this by printing the variables, as sketched below.
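A minimal way to inspect the shapes (the printed values assume the settings from part 1: batch_size = 5, state_size = 4, truncated_backprop_length = 15):

```python
# Shape check (values assume batch_size=5, state_size=4, truncated_backprop_length=15)
print(x_inputs[0].get_shape())       # (5, 1): split keeps the singular dimension
print(tf.unpack(batchX_placeholder, axis=1)[0].get_shape())  # (5,): unpack drops it
print(len(states_series))            # 15: one state tensor per unrolled time step
print(states_series[0].get_shape())  # (5, 4)
print(current_state.get_shape())     # (5, 4): the last state
```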

Whole program

Here is the full code:
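What follows is a condensed sketch of the whole script: hyperparameters and variable names are assumed from part 1 of this series, it targets the pre-1.0 TensorFlow API (so, as noted above, it will not run on TensorFlow 2.x), and the matplotlib visualization is left out for brevity.

```python
from __future__ import print_function, division
import numpy as np
import tensorflow as tf

# Hyperparameters, assumed from part 1 of this series
num_epochs = 100
total_series_length = 50000
truncated_backprop_length = 15
state_size = 4
num_classes = 2
echo_step = 3
batch_size = 5
num_batches = total_series_length // batch_size // truncated_backprop_length

def generateData():
    # Random binary input; the label is the input echoed echo_step steps later
    x = np.array(np.random.choice(2, total_series_length, p=[0.5, 0.5]))
    y = np.roll(x, echo_step)
    y[0:echo_step] = 0
    x = x.reshape((batch_size, -1))  # Cut the series into batch_size rows
    y = y.reshape((batch_size, -1))
    return (x, y)

batchX_placeholder = tf.placeholder(tf.float32, [batch_size, truncated_backprop_length])
batchY_placeholder = tf.placeholder(tf.int32, [batch_size, truncated_backprop_length])
init_state = tf.placeholder(tf.float32, [batch_size, state_size])

# Output-layer weights; the RNN's own W and b now live inside the cell
W2 = tf.Variable(np.random.rand(state_size, num_classes), dtype=tf.float32)
b2 = tf.Variable(np.zeros((1, num_classes)), dtype=tf.float32)

# Split columns: a list of truncated_backprop_length tensors of shape [batch_size, 1]
x_inputs = tf.split(1, truncated_backprop_length, batchX_placeholder)
labels_series = tf.unpack(batchY_placeholder, axis=1)

# Forward pass: tf.nn.rnn unrolls the cell and builds the graph for us
cell = tf.nn.rnn_cell.BasicRNNCell(state_size)
states_series, current_state = tf.nn.rnn(cell, x_inputs, init_state)

# Output layer, loss, and training step, one per unrolled time step
logits_series = [tf.matmul(state, W2) + b2 for state in states_series]
predictions_series = [tf.nn.softmax(logits) for logits in logits_series]
losses = [tf.nn.sparse_softmax_cross_entropy_with_logits(logits, labels)
          for logits, labels in zip(logits_series, labels_series)]
total_loss = tf.reduce_mean(losses)
train_step = tf.train.AdagradOptimizer(0.3).minimize(total_loss)

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    for epoch_idx in range(num_epochs):
        x, y = generateData()
        # Zero the state at the start of each new series
        _current_state = np.zeros((batch_size, state_size))
        for batch_idx in range(num_batches):
            start_idx = batch_idx * truncated_backprop_length
            end_idx = start_idx + truncated_backprop_length
            batchX = x[:, start_idx:end_idx]
            batchY = y[:, start_idx:end_idx]
            # Carry the last state over to the next truncated-backprop window
            _total_loss, _train_step, _current_state = sess.run(
                [total_loss, train_step, current_state],
                feed_dict={
                    batchX_placeholder: batchX,
                    batchY_placeholder: batchY,
                    init_state: _current_state
                })
            if batch_idx % 100 == 0:
                print("Epoch", epoch_idx, "Step", batch_idx, "Loss", _total_loss)
```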

Next step

In the next post we will improve the RNN by using another architecture called “Long Short-Term Memory”, or LSTM. Strictly speaking this is not necessary, since our network can already solve the toy problem. But remember that our goal is to learn to use TensorFlow properly, not to solve the actual problem, which is trivial :)
