Using the DynamicRNN API in TensorFlow (5/7)
In the previous guide we built a multi-layered LSTM RNN. In this post we will speed it up by not splitting our inputs and labels into Python lists, as done on lines 41–42 of that code. You may remove the lines where labels_series is declared. Next, change the tf.nn.rnn call on line 47 to the following:
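A sketch of the change, padded with the placeholder and cell setup from the previous part so it runs on its own. The names batchX_placeholder, init_state, rnn_tuple_state and cell are assumed to be carried over from earlier in the series, and the compat.v1 import is only there so the 1.x-style calls also run under TensorFlow 2; the last two lines are the actual replacement:

```python
import tensorflow.compat.v1 as tf  # the series targets the TensorFlow 1.x API
tf.disable_v2_behavior()

batch_size = 5
truncated_backprop_length = 15
state_size = 4
num_layers = 3

# Placeholders and the multi-layered cell as set up in the previous part
batchX_placeholder = tf.placeholder(tf.float32, [batch_size, truncated_backprop_length])
init_state = tf.placeholder(tf.float32, [num_layers, 2, batch_size, state_size])
state_per_layer_list = tf.unstack(init_state, axis=0)
rnn_tuple_state = tuple(
    tf.nn.rnn_cell.LSTMStateTuple(state_per_layer_list[idx][0], state_per_layer_list[idx][1])
    for idx in range(num_layers))
cell = tf.nn.rnn_cell.MultiRNNCell(
    [tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True) for _ in range(num_layers)],
    state_is_tuple=True)

# dynamic_rnn expects [batch_size, truncated_backprop_length, input_size],
# so expand_dims adds a trailing input dimension of size 1
states_series, current_state = tf.nn.dynamic_rnn(
    cell, tf.expand_dims(batchX_placeholder, -1), initial_state=rnn_tuple_state)
states_series = tf.reshape(states_series, [-1, state_size])
```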
The dynamic_rnn function takes batch inputs of shape [batch_size, truncated_backprop_length, input_size], hence the addition of a single dimension at the end. The output will be the last state of every layer in the network as an LSTMStateTuple stored in current_state, as well as a tensor states_series with shape [batch_size, truncated_backprop_length, state_size] containing the hidden state of the last layer across all time-steps. states_series is reshaped on the second row in the code sample above to shape [batch_size*truncated_backprop_length, state_size]; we will see the reason for this shortly. You may read more about dynamic_rnn in the documentation.
Now add these two lines below the reshaping of states_series:
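The two lines compute the logits with a matrix multiplication and flatten the labels. Here is a self-contained sketch, with states_series stubbed as a placeholder standing in for the reshaped dynamic_rnn output, and W2, b2 and batchY_placeholder assumed from the earlier parts of the series:

```python
import tensorflow.compat.v1 as tf  # 1.x-style API, also usable under TensorFlow 2
tf.disable_v2_behavior()

batch_size = 5
truncated_backprop_length = 15
state_size = 4
num_classes = 2

# Stand-in for the reshaped output of dynamic_rnn from the previous step
states_series = tf.placeholder(tf.float32, [batch_size * truncated_backprop_length, state_size])
batchY_placeholder = tf.placeholder(tf.int32, [batch_size, truncated_backprop_length])
W2 = tf.Variable(tf.random_normal([state_size, num_classes]))
b2 = tf.Variable(tf.zeros([1, num_classes]))

# Logits for all (batch, time-step) elements at once; the bias is broadcast.
# The labels are flattened to the same leading dimension.
logits = tf.matmul(states_series, W2) + b2
labels = tf.reshape(batchY_placeholder, [-1])
```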
Notice that we are now working only with tensors; Python lists are a thing of the past. The calculations of the logits and the labels are visualized below; notice the states_series variable that was reshaped earlier. In TensorFlow, reshaping is done in C-like index order. This means that we read from the source tensor and “write” to the destination tensor with the last axis index changing fastest and the first axis index changing slowest. The result of the reshaping is visualized in the figure below, where similar colors denote the same time-step and the vertical grouped spacing of elements denotes different batches.
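The C-order claim is easy to check in NumPy with small, made-up sizes: row b*T + t of the flattened tensor is the hidden state of batch b at time-step t, so consecutive rows cycle through the time-steps of one batch before moving to the next batch.

```python
import numpy as np

batch_size, T, state_size = 2, 3, 4  # small illustrative sizes, not the series' values
x = np.arange(batch_size * T * state_size).reshape(batch_size, T, state_size)
y = x.reshape(-1, state_size)  # C-like order: the last axis index changes fastest

# Row b*T + t of the flattened array equals element [b, t] of the original
assert np.array_equal(y[1 * T + 2], x[1, 2])
```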
Let’s go through all the tensors in the figure above, starting with the sizes. We have truncated_backprop_length=3. The tensor states_series has shape [batch_size*truncated_backprop_length, state_size], labels has shape [batch_size*truncated_backprop_length], logits has shape [batch_size*truncated_backprop_length, num_classes], W2 has shape [state_size, num_classes] and b2 has shape [1, num_classes]. It can be a bit tricky to keep track of all the tensors, but drawing and visualizing with colors definitely helps.
Next calculate the predictions for the visualization:
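A sketch of the predictions calculation, with logits stubbed as a placeholder standing in for the tensor computed above: apply the softmax over the classes, fold the time dimension back out, then unstack into a list of per-time-step tensors.

```python
import tensorflow.compat.v1 as tf  # 1.x-style API, also usable under TensorFlow 2
tf.disable_v2_behavior()

batch_size = 5
truncated_backprop_length = 15
num_classes = 2

# Stand-in for the logits computed above
logits = tf.placeholder(tf.float32, [batch_size * truncated_backprop_length, num_classes])

# Softmax over classes, reshape back to [batch, time, classes], then
# unstack along the time axis into a list for the plot function
predictions_series = tf.unstack(
    tf.reshape(tf.nn.softmax(logits),
               [batch_size, truncated_backprop_length, num_classes]),
    axis=1)
```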
Here we actually split the tensors into lists again. This is perhaps not the best way to do it, but it’s quick and dirty, and the plot function is already expecting a list.
It turns out that sparse_softmax_cross_entropy_with_logits can take the shapes of our tensors directly! Modify the losses calculation to this:
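A self-contained sketch of the new losses calculation, with logits and labels stubbed as placeholders standing in for the tensors built in the steps above:

```python
import tensorflow.compat.v1 as tf  # 1.x-style API, also usable under TensorFlow 2
tf.disable_v2_behavior()

batch_size = 5
truncated_backprop_length = 15
num_classes = 2

# Stand-ins for the logits and flattened labels from the steps above
logits = tf.placeholder(tf.float32, [batch_size * truncated_backprop_length, num_classes])
labels = tf.placeholder(tf.int32, [batch_size * truncated_backprop_length])

# One loss value per (batch, time-step) element, averaged into a scalar
losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels)
total_loss = tf.reduce_mean(losses)
```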
As we can read in the API documentation, logits must have shape [batch_size, num_classes] and labels must have shape [batch_size]. But since we are now treating all time-steps as elements of our batch, it will work out as we want.
Here is the whole self-contained script; just copy and run.
In the next part we will regularize the network with dropout, making it less prone to overfitting.