A simple design pattern for recurrent deep learning in TensorFlow

Dev Nag
5 min read · Oct 8, 2016


tl;dr: You can hide/encapsulate the state of arbitrary recurrent networks with a single page of code.

In an ideal world, every deep learning paper proposing a new architecture would link to a readily accessible GitHub repository with implemented code.

In reality, you often have to hand-code the translated equations yourself, make a bunch of assumptions, and do a lot of debugging before you get something that may or may not be related to the authors’ intent.

This process is especially fraught when dealing with recurrent architectures (aka “recurrent neural networks”): computational graphs which are DGs (directed graphs) but not DAGs (directed acyclic graphs). Recurrent architectures are especially good at modeling/generating sequential data — language, music, video, even video games — anything where you care about the order of data rather than just pure input/output mapping.

However, because we can’t directly train directed graphs with directed cycles (whew!), we have to implement and train graphs that are transformations of the original graph (going from “cyclic” to “unrolled” versions) and then use backpropagation through time (BPTT) on these shadow models. In essence, we’re mapping connections across time to connections across space:

Like a fitness commercial except the “before” is way better-looking

Now, if you’re just using vanilla LSTM/GRU, there are off-the-shelf components that you can duct-tape together easily. That’s not the problem. The hard part is taking a new recurrent architecture and trying to code the novel graph while also handling all the unrolled state tricks without introducing new bugs.

Graves (2013), Eq 7–11

For example, suppose you found yourself perusing Alex Graves’ bodice-ripping classic from 2013, “Generating Sequences with Recurrent Neural Networks”, and wanted to implement his LSTM.

Sigh.

Everywhere you see t-1 as a subscript is yet another place (and yes, Virginia, there are 7 in that little brick of symbols) that you need to worry about recurrent state: initializing it, retrieving it from the past, and saving it for the future.
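For reference, here is the brick in question as the equations appear in Graves (2013), Eq. 7–11 — the seven recurrent terms are the four \(h_{t-1}\)’s and three \(c_{t-1}\)’s:

```latex
\begin{aligned}
i_t &= \sigma\!\left(W_{xi} x_t + W_{hi} h_{t-1} + W_{ci} c_{t-1} + b_i\right) \\
f_t &= \sigma\!\left(W_{xf} x_t + W_{hf} h_{t-1} + W_{cf} c_{t-1} + b_f\right) \\
c_t &= f_t \, c_{t-1} + i_t \tanh\!\left(W_{xc} x_t + W_{hc} h_{t-1} + b_c\right) \\
o_t &= \sigma\!\left(W_{xo} x_t + W_{ho} h_{t-1} + W_{co} c_t + b_o\right) \\
h_t &= o_t \tanh(c_t)
\end{aligned}
```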

If you look at TensorFlow tutorials, you’ll see a lot of code dedicated to worrying about packing and unpacking the various recurrent states. There’s much room for error here, and a cruel and unusual intermingling of architecture and logistics. It becomes half-declarative, half-procedural…and all-fugly.

But with just a tiny amount of bookkeeping code, you can make it so much easier, and almost live the dream (!) of typing in recurrent equations and getting working TensorFlow code out the other side. We can even steal TensorFlow’s idiom for get-or-create scoped variables. Let’s briefly look at the relevant stanza:

LSTM with a Pearl Earring, 2013

Here, the bp variable is a BPTT object that’s responsible for all recurrent state. There are two interesting method calls here. bp.get_past_variable() handles both initialization from a random constant and retrieval of past state (t-1), and bp.name_variable() saves the current state for future suitors.

Look how close this code is to the raw mathematical equations — I’ve left out the shape definitions for the weight matrices and bias vectors for clarity, but for the most part it’s a 1-to-1 mapping: easy to write and easy to read.

The only mental overhead is retrieving the recurrent variable (immediately before usage) and saving it (inline with usage), all in local context. There’s no reference to this state anywhere else in the graph-building code. In fact, the rest of the code bears a striking resemblance to non-recurrent TensorFlow.
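The actual graph-building code is in the screenshot above; as a hypothetical reconstruction of the same idiom, here’s a NumPy sketch. The BPTT class below is just a dict-backed stand-in for the bookkeeping (the real object also builds TensorFlow placeholders for unrolling), and the weight names (W['xi'], etc.) follow Graves’ notation, not necessarily the article’s actual code:

```python
import numpy as np

class BPTT:
    """Dict-backed stand-in for the article's BPTT bookkeeping object.
    (Hypothetical sketch; the real class also creates TensorFlow
    placeholders so the graph can be unrolled and stitched together.)"""
    def __init__(self):
        self.state = {}

    def get_past_variable(self, name, init):
        # Initialization from a constant on step 0; retrieval of t-1 after.
        return self.state.get(name, init)

    def name_variable(self, name, value):
        # Save the current state for the future, inline with usage.
        self.state[name] = value
        return value

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(bp, x, W, b, n_hidden):
    """One Graves-style LSTM step, written against the bp idiom."""
    h_prev = bp.get_past_variable('h', np.zeros(n_hidden))
    c_prev = bp.get_past_variable('c', np.zeros(n_hidden))
    # Peephole weights (W['ci'], W['cf'], W['co']) are diagonal in
    # Graves (2013), hence the elementwise multiply below.
    i = sigmoid(W['xi'] @ x + W['hi'] @ h_prev + W['ci'] * c_prev + b['i'])
    f = sigmoid(W['xf'] @ x + W['hf'] @ h_prev + W['cf'] * c_prev + b['f'])
    c = bp.name_variable(
        'c', f * c_prev + i * np.tanh(W['xc'] @ x + W['hc'] @ h_prev + b['c']))
    o = sigmoid(W['xo'] @ x + W['ho'] @ h_prev + W['co'] * c + b['o'])
    h = bp.name_variable('h', o * np.tanh(c))
    return h
```

Note how each `t-1` term in the equations becomes exactly one `get_past_variable()` call, and each new state exactly one `name_variable()` call — the 1-to-1 mapping the article describes.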

Then, to generate the shadow (unrolled) model, we just call on the bp object to generate the sequence of connected graphs with a single line:

Unrolling the user-defined graph-building function

This sequence of graphs has placeholders at the right places (where those inline constants will come back to make a dramatic cameo) and is stitched together at every bp.get_past_variable() call.
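In spirit (the real unrolling code lives in the linked repo), that single line reduces to a loop that invokes the user-defined graph-building function once per time step, with bp threading the state between calls. A toy, TensorFlow-free sketch, using a running sum as the step function:

```python
class BPTT:
    # Minimal dict-backed state bookkeeping (hypothetical stand-in;
    # the real object also wires up placeholders for the unrolled graph).
    def __init__(self):
        self.state = {}
    def get_past_variable(self, name, init):
        return self.state.get(name, init)
    def name_variable(self, name, value):
        self.state[name] = value
        return value

def unroll(bp, step_fn, inputs):
    """The 'single line': apply the user-defined step function across the
    sequence, stitching state through bp at each call."""
    return [step_fn(bp, x) for x in inputs]

def running_sum(bp, x):
    # A toy step written in the get/name idiom.
    prev = bp.get_past_variable('s', 0)
    return bp.name_variable('s', prev + x)

print(unroll(BPTT(), running_sum, [1, 2, 3, 4]))  # [1, 3, 6, 10]
```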

During training (or inference), there are three places where all this recurrent state must be brought back into play. First, we have to send the working state into the feed dictionary (either the initialized constants defined above, or the working state from a previous training loop), and insert the training data into the unrolled placeholders. Second, we have to define the state variables we want returned from the session.run() method. Third, we have to take the post-session.run() results and extract out the state for the future.
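Concretely, with a stub standing in for session.run() (hypothetical — the real call evaluates the unrolled TensorFlow graph), the three-step loop looks something like:

```python
def fake_session_run(fetches, feed_dict):
    # Stub for session.run(): pretend the "network" doubles its state.
    # (fetches is ignored here; a real session uses it to decide which
    # tensors to evaluate and return.)
    return {'state_out': feed_dict['state_in'] * 2.0}

state = 1.0  # the initialized constant from the graph definition
for batch in ([1, 2], [3, 4]):
    # 1) send the working state + training data into the feed dictionary
    feed = {'state_in': state, 'x': batch}
    # 2) ask for the state variables back as part of the fetches
    out = fake_session_run(['state_out'], feed)
    # 3) extract the resulting state for the next loop iteration
    state = out['state_out']
print(state)  # 4.0
```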

Training idiom — retrieve state, send state + new data in, get resulting state/data out

The BPTT object takes care of that bookkeeping as well.

Note that we also pass a flag (bp.DEEP) in many of these calls during the training phase. This is because another common design pattern of recurrent networks is to first train the unrolled/deep network but then infer using the cyclic/shallow network (with the same, shared post-training parameters).

When we infer, we use the bp.SHALLOW flag, which selects a different set of placeholder variables and thus a different state pipeline. There’s also a convenience method (copy_state_forward()) to copy the final unrolled/deep state (recurrent variables) into the cyclic/shallow network before starting inference.
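One plausible shape for that dual pipeline — hypothetical, since only copy_state_forward() and the two flags are named in the article — is a mode-keyed state dict:

```python
class BPTT:
    # Hypothetical sketch: two independent state pipelines, keyed by mode.
    DEEP, SHALLOW = 'deep', 'shallow'

    def __init__(self):
        self.state = {self.DEEP: {}, self.SHALLOW: {}}

    def get_past_variable(self, name, init, mode=DEEP):
        return self.state[mode].get(name, init)

    def name_variable(self, name, value, mode=DEEP):
        self.state[mode][name] = value
        return value

    def copy_state_forward(self):
        # Seed the cyclic/shallow network with the final unrolled/deep
        # state before starting inference.
        self.state[self.SHALLOW] = dict(self.state[self.DEEP])
```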

Inference idiom — almost identical to the training phase, but feeding one frame of data at a time

Recurrent deep learning in TensorFlow can be — if not easy — a little bit easier.

Want to check out the code + sample usage? Say no more.
