Pointer Networks in TensorFlow (with sample code)

tl;dr: Deep learning networks can be applied to variable-length targets, meaning you can index into arbitrary text, time series, or any sequence of data your selfish little heart desires.

I recently read a paper that described a new state-of-the-art result on the Stanford Question Answering Dataset (SQuAD). The performance (F1 of 70%) was impressive, but especially interesting was an architectural capability I hadn’t noticed when it appeared in the literature a year earlier — the ability to compute variable-length probability distributions, over arbitrary inputs, with a fixed architecture.

The SQuAD task is a nice step on the way to linguistic AI; you basically get a medium-length text passage, a variety of questions about that passage, and then text answers. The catch — and what makes this task more ‘feasible’ than full-fledged Q&A — is that the answer has to be a contiguous sequence of letters or words in the original passage. In other words, the answer to any of the questions has to be a pair of pointers: one pointer to the start of the ‘answer range’ and one pointer to the end (say, from character 47 to character 65, or word 14 through word 17).
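In code, this ‘pair of pointers’ framing is just span indexing. A minimal sketch — the passage text and indices below are illustrative, not taken from the actual dataset:

```python
# A SQuAD-style answer is a pair of indices into the tokenized passage.
# Hypothetical mini-passage for illustration:
passage = ("A composite number is a positive integer "
           "with at least one divisor other than one and itself").split()

start, end = 1, 2  # the two pointers, as inclusive word indices
answer = " ".join(passage[start:end + 1])
print(answer)  # -> "composite number"
```

The model never generates answer text; it only has to emit these two indices.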

SQuAD example: note that the answer to question #2 on the right is “composite number”, which is found (verbatim) in the passage to the left

Obviously, not all questions can be answered in such a way, so this is something of an “easy subset” of all possible, relevant questions (“easy” being relative, of course). But if you actually peruse the dataset over a chilled glass of Pinot Grigio, it’s fascinating just how many meaningful questions fit into this merest haiku of a linguistic task.

Now, let’s describe the naive way that you might try to attack this problem. Typically, your output layer will be a vector that’s either a distributed representation (maybe a so-called ‘sentence vector’ or ‘thought vector’) or a one-hot representation (representing a probability distribution over the set of vector slots/dimensions).

But there’s a problem.

You have to pick an output size. You have to train that final matrix with some fixed output dimensions (whether 10 or 10,000) and that’s the size of the output vector you’re going to get.

How do we apply this to SQuAD? What if one passage is 300 words long and another is 3,000 words long? Maybe you should pick the longest length passage you’ll accept (say, 10,000) and just hope that (1) you won’t be wasting too much computation time training shorter passages on a giant architecture and that (2) learning will actually transfer well across different passage lengths inside this fixed-length box. Unfortunately, this approach is just as fragile as it sounds.

There’s a better way.

The state-of-the-art SQuAD paper described above (Wang & Jiang, 2016) used some common designs (heavy use of bidirectional LSTMs, alignment matrices as used in translation tasks, etc.) but also mentions a technique called Pointer Networks that is, indeed, that better way.

Here’s the basic idea, architecturally: we’ll train a fixed-size network but map it over variable-length input to get variable-length output (pointers).

To do this, we start with the sequence-to-sequence design pattern explained here, folding the input sequence (whatever the length) into a fixed-size hidden state. Then, we’ll unfold that hidden state into a series of soft ‘pointers’ — probability distributions over the input sequence. In the SQuAD example above, as well as our coming example, there are two pointers (start and end), so we unfold the hidden state twice — the first time, to get the ‘start pointer’ probability distribution over all the inputs, and the second time, to get the ‘end pointer’ probability distribution over all the inputs. (If we had a different problem that required one or three or fourteen pointers, we would have to unfold the hidden state one or three or fourteen times.)
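Here’s a minimal sketch of a single ‘unfold’ step, assuming plain dot-product scoring for simplicity (the original Pointer Network paper uses an additive form, v·tanh(W1·e_j + W2·d_i)); all the vectors below are made-up toy values:

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def pointer_step(decoder_state, encoder_states):
    # Score every input position against the current decoder state, then
    # softmax the scores into a distribution *over the input positions*.
    # That distribution is the soft pointer.
    scores = [sum(d * e for d, e in zip(decoder_state, enc))
              for enc in encoder_states]
    return softmax(scores)

# Toy 2-d encoder states for a 5-step input (hypothetical values):
enc_states = [[0.1, 0.2], [0.9, 0.8], [0.8, 0.9], [0.2, 0.1], [0.0, 0.3]]
start_state = [1.0, 1.0]  # decoder state from the first 'unfold'

dist = pointer_step(start_state, enc_states)
print(len(dist), round(sum(dist), 6))  # 5 positions, probabilities sum to 1
```

Note the key property: the output distribution has exactly as many entries as the input has positions, so the same fixed-size parameters handle any sequence length.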

Here’s a schematic of the whole pipeline:

From question to answer, with a bunch of scalding hot GPUs in the middle

These pointer networks are a particular example of what’s known as content-based attention — using the values of the incoming data to decide dynamically where to ‘attend’ (or ‘point’, with the pointers/indices). This can be contrasted with location-based attention, which basically says “keep looking in position X” for the datum of interest.
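To make the contrast concrete, here’s a toy sketch (hypothetical helper names, hard argmax in place of a soft distribution):

```python
# Location-based: always attend to a fixed position, whatever the content.
def location_attend(seq, position=2):
    return seq[position]

# Content-based: attend wherever the content itself says to look --
# here, the index of the largest value.
def content_attend(seq):
    return max(range(len(seq)), key=lambda i: seq[i])

seq = [3, 1, 9, 2, 7]
print(location_attend(seq))  # -> 9 (whatever happens to sit at position 2)
print(content_attend(seq))   # -> 2 (the index of the largest value)
```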

Let’s try out some code on a toy problem. Pointer networks are really most relevant for recurrency-sensitive data sequences, so we’ll create one. Suppose our input data is a sequence of integers between 0 and 10 (with possible duplicates) of unknown length. Each sequence always begins with low integers (random values from 1 to 5), has a run of high integers (random values from 6 to 10), then turns low again to finish (1 to 5).

For example, a sequence might be “4,1,2,3,1,1,6,9,10,8,6,3,1,1”, with the run of high integers (6, 9, 10, 8, 6) surrounded by runs of low integers. We want to train a network that can point to these two change points — the beginning and end of the run of highs in the middle, regardless of the sequence length.
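A generator for this toy data might look like the following sketch; the segment-length bounds here are placeholders (the article’s actual training segments are longer, 11–20 integers each):

```python
import random

def make_sequence(low=(1, 5), high=(6, 10), seg_len=(3, 8)):
    """Generate one low/high/low sequence plus its pointer targets."""
    lows1 = [random.randint(*low) for _ in range(random.randint(*seg_len))]
    highs = [random.randint(*high) for _ in range(random.randint(*seg_len))]
    lows2 = [random.randint(*low) for _ in range(random.randint(*seg_len))]
    seq = lows1 + highs + lows2
    start = len(lows1)                 # index of the first high value
    end = len(lows1) + len(highs) - 1  # index of the last high value
    return seq, start, end

seq, start, end = make_sequence()
assert all(6 <= v <= 10 for v in seq[start:end + 1])  # highs in the middle
```

The start/end indices come for free from the construction, which is what we train the two pointer distributions against.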

For a moment, think about why this problem is difficult:

  • First of all, there’s an invisible dividing line between ‘high’ and ‘low’ that we’re not explicitly telling the pointer network about.
  • It’s also not about the raw value of any individual element, but about the values before and after — about the context around each element.
  • And finally, we’re training the network to point to the first high value, but on the other end it must point not to the first low value after the high segment, but to the last high value before the lows start showing up again. In other words, you could never solve this problem in a streaming fashion, because it would require knowledge of the future.

Here are some longer sample sequences:

Three separate low/high/low sequences, with 0-padding at the end, and the start (^1) and end (^2) pointers directly under each sequence

Note that these sequences have different lengths — the second sequence is the longest (ending in 4 1 2 2) and the third is the shortest (ending in 2 2 5 3). The pointers (^1 and ^2) point to the beginning and end of the high sequence in the middle — for the first sequence, they point to the “7” and the “6” in the high run of “7 10 6 8 10 7 8 6”, for example.

To make this problem even harder (and the solution more impressive, natch), we’re going to train the pointer network on sequences that have longer segments (longer than ten integers) of highs/lows, but then test it on sequences that have shorter segments (shorter than ten integers). So not only are we testing on sequences that were never seen in training; we’re testing on segment lengths that were never seen in training — just to see if the network has truly generalized our pattern.

Note that loss trajectories are often non-monotonic. Have faith.

During training on the longer sequences, the pointer network takes a while to break through the initial plateau of 0.11, but then quickly begins to figure out what’s up with this data, with the loss dropping to < 0.01 after about 2000 training steps.
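The loss being tracked here would naturally be the sum of two categorical cross-entropies, one per unfolded pointer; a sketch with made-up distributions:

```python
import math

def pointer_loss(start_dist, end_dist, start_target, end_target):
    # One cross-entropy term per pointer: negative log-probability the
    # model assigned to the true start index and the true end index.
    return -(math.log(start_dist[start_target]) +
             math.log(end_dist[end_target]))

# Hypothetical predicted distributions over a 5-element input:
start_dist = [0.05, 0.80, 0.05, 0.05, 0.05]
end_dist   = [0.05, 0.05, 0.05, 0.80, 0.05]
loss = pointer_loss(start_dist, end_dist, start_target=1, end_target=3)
print(round(loss, 3))  # -> 0.446
```

As the model sharpens both distributions around the true indices, this loss drives toward zero, which matches the sub-0.01 values reported above.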

Now, how about inference? Can training on these sequences (of 33–60 integers, with segments of 11–20 integers each) possibly help when testing on sequences that are only 15–30 integers?

During a typical run, on out-of-sample data, we find the correct start/end indices 96% of the time.
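‘Correct’ here means both argmax pointers match their targets exactly; a sketch of that exact-match check (hypothetical helper name):

```python
def exact_match(start_dist, end_dist, start_target, end_target):
    # Correct only if *both* argmax pointers land on the right indices.
    pred_start = max(range(len(start_dist)), key=start_dist.__getitem__)
    pred_end = max(range(len(end_dist)), key=end_dist.__getitem__)
    return pred_start == start_target and pred_end == end_target

start_dist = [0.1, 0.7, 0.1, 0.1]
end_dist = [0.1, 0.1, 0.1, 0.7]
print(exact_match(start_dist, end_dist, 1, 3))  # -> True
```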

On sequences of completely different lengths than what we trained on.

Not bad.

You’ve gone as far as you can without running the code, young Padawan. Try it out yourself!