Learning a function with a variable number of inputs with PyTorch

Fascinated by the idea that dynamic neural networks can be constructed easily with PyTorch, I decided to give it a try.

The application I have in mind has a variable number of inputs of the same type. For a variable number of inputs, recurrent or recursive neural networks have been used. However, these structures impose some ordering or hierarchy on the inputs of a given row. When there is no such relationship, these approaches may not be optimal.

Here is an example: given a set of points drawn from a Gaussian distribution, estimate their mean and variance. Note that the number of points can differ from row to row of the sample.

Researchers at DeepMind have recently published the idea of ‘relation networks’, which take features of a variable number of objects as inputs and treat them in an order-invariant way, without having to train on permuted duplications of the original data set.

In particular, they are interested in looking at pairs of objects. One of the network structures they propose first passes the features of each of the n object pairs through a network g, then aggregates the outputs of these n copies of g by summation and feeds the result into a second network f:

Ordering invariant network structure for pairs of objects (from arxiv:1702.05068)

The important part is that the summation actually imposes the order invariance of the pairs.
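To make the role of the summation concrete, here is a small sketch (not taken from the paper or the notebook): a stand-in network g with arbitrarily chosen sizes is applied to the same inputs in two different orders, and the summed outputs are compared. All names and sizes are assumptions for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in for g: maps each scalar input to a 4-dimensional vector
# (layer sizes are illustrative assumptions)
g = nn.Sequential(nn.Linear(1, 4), nn.ReLU())

points = torch.randn(6, 1)             # six one-dimensional inputs
shuffled = points[torch.randperm(6)]   # the same inputs in a different order

pooled = g(points).sum(0)              # sum over the six copies of g
pooledShuffled = g(shuffled).sum(0)

# The sums agree up to floating point rounding, whatever the order
sameResult = torch.allclose(pooled, pooledShuffled, atol=1e-5)
print(sameResult)
```

Because g is applied to every input independently and addition is commutative, whatever is fed into f afterwards cannot depend on the input order.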

Note also that the outputs of the g networks are typically vectors (i.e. more than one value).

Inspired by this, I tried to implement a network in PyTorch which can learn to estimate the variance of a Gaussian distribution, given a variable number of points drawn from it.

Instead of using pairs of objects, we use simple one-dimensional input values. Let’s start by defining how many rows we want to generate, how many points each row may contain, and so on:
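A sketch of what such settings could look like. The names and values below are assumptions chosen for illustration, not necessarily those used in the notebook.

```python
import numpy as np

np.random.seed(1234)          # fixed seed so that the sample is reproducible

# All names and values are illustrative assumptions
numRowsTrain = 10000          # rows in the training sample
numRowsTest = 10000           # rows in the test sample
minNumPoints = 10             # smallest number of points in a row
maxNumPoints = 20             # largest number of points in a row
varMin, varMax = 0.5, 1.5     # range from which the true variances are drawn
```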

and then we draw the variances to be used for the Gaussians:
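A minimal sketch of this step, assuming the variances are drawn uniformly from a fixed range. The distribution, the range and the variable names are assumptions; the article does not state them.

```python
import numpy as np

rng = np.random.default_rng(1234)

numRows = 1000                 # illustrative sample size
varMin, varMax = 0.5, 1.5      # assumed range of true variances

# One true variance per row, drawn uniformly (an assumption)
trueVariances = rng.uniform(varMin, varMax, size=numRows)
trueSigmas = np.sqrt(trueVariances)   # standard deviations, needed for sampling
```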

Then we generate the actual points: we first need to draw random values to determine how many points each row of the sample should contain. We then generate the points themselves (for simplicity, we center all Gaussians at zero):
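The two steps could be sketched with NumPy as follows. Since rows have different lengths, the sample is kept as a list of arrays rather than one rectangular array. Names, ranges and sizes are again assumptions.

```python
import numpy as np

rng = np.random.default_rng(1234)

numRows = 1000
minNumPoints, maxNumPoints = 10, 20
trueSigmas = np.sqrt(rng.uniform(0.5, 1.5, size=numRows))

# Step 1: number of points per row; the upper bound of rng.integers is
# exclusive, hence the + 1
numPoints = rng.integers(minNumPoints, maxNumPoints + 1, size=numRows)

# Step 2: the points themselves, one zero-centered 1D array per row
points = [rng.normal(loc=0.0, scale=sigma, size=n)
          for sigma, n in zip(trueSigmas, numPoints)]
```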

We can visualize an example row:

Here the blue curve shows the actual Gaussian and the orange dots represent the values drawn from this instance. The green line shows (twice) the true square root of the variance, and the red and purple lines correspond to the unbiased and maximum likelihood (ML) estimates, respectively.
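For reference, the two classical estimates differ only in their normalisation: the ML estimate divides the sum of squared deviations by n, the unbiased one by n - 1. A short sketch, assuming the mean is estimated from the data as well (NumPy's default behaviour) rather than fixed at zero:

```python
import numpy as np

rng = np.random.default_rng(0)

n = 15
row = rng.normal(loc=0.0, scale=1.2, size=n)   # one illustrative row

mlVariance = np.var(row, ddof=0)         # maximum likelihood: divide by n
unbiasedVariance = np.var(row, ddof=1)   # unbiased: divide by n - 1

print(np.sqrt(mlVariance), np.sqrt(unbiasedVariance))
```

For the short rows used here the two differ by a factor n / (n - 1), which is not negligible.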

Now let’s define the network structure. The constructor takes several parameters defining how many layers the network should have on the ‘input’ (g) and ‘output’ (f) side and how wide these layers should be:
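A constructor along these lines might look like the sketch below. The class name, parameter names and default sizes are assumptions, and only the layer bookkeeping is shown; the forward pass is deliberately left out.

```python
import torch.nn as nn

class Net(nn.Module):
    # All parameter names and defaults are illustrative assumptions
    def __init__(self, numLayersInputSide=5, widthInputSide=50,
                 numLayersOutputSide=5, widthOutputSide=50):
        super().__init__()

        # Input side: applied to every point separately, so it starts
        # from a single input value
        self.inputSideLayers = nn.ModuleList()
        numInputs = 1
        for _ in range(numLayersInputSide):
            self.inputSideLayers.append(nn.Linear(numInputs, widthInputSide))
            numInputs = widthInputSide

        # Output side: applied once per row to the aggregated vector
        self.outputSideLayers = nn.ModuleList()
        for _ in range(numLayersOutputSide - 1):
            self.outputSideLayers.append(nn.Linear(numInputs, widthOutputSide))
            numInputs = widthOutputSide

        # Final layer: a single output, the estimated variance
        self.outputSideLayers.append(nn.Linear(numInputs, 1))

model = Net()
```

Using nn.ModuleList (rather than a plain Python list) registers the layers with the module, so that their weights show up in model.parameters().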

The core of the code resides in the network’s forward method:

You’ll notice the line:

output = h.sum(0) / len(thisPoints)

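which aggregates the output tensors of the network at the input side. Note that dividing by len(thisPoints) turns the sum into a mean over the points of the row; like the sum, the mean does not depend on the order of the points.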
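To show where this line sits, here is a stripped-down sketch of such a forward pass, written as a plain function over two small stand-in networks. Following the relation network description, g is applied to each point and f to the aggregate; the layer sizes and the helper name forwardRows are assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Small stand-ins for the input side (g) and output side (f)
g = nn.Sequential(nn.Linear(1, 8), nn.ReLU(), nn.Linear(8, 8), nn.ReLU())
f = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))

def forwardRows(rows):
    """rows: list of 1D tensors, possibly of different lengths."""
    outputs = []
    for thisPoints in rows:
        # (numPoints,) -> (numPoints, 1) -> (numPoints, 8)
        h = g(thisPoints.unsqueeze(1))
        # aggregate over the points of this row -> (8,)
        output = h.sum(0) / len(thisPoints)
        # one prediction per row -> (1,)
        outputs.append(f(output))
    return torch.cat(outputs)

rows = [torch.randn(n) for n in (3, 7, 12)]
predictions = forwardRows(rows)     # one value per row, shape (3,)
```

The Python loop over rows is what makes the variable row length unproblematic: each row is pushed through g as its own small batch.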

For the sake of simplicity, we do not use any regularization such as dropout layers. Also, the size of the network was chosen by hand and does not come from a more rigorous procedure such as k-fold cross-validation.

The code for preparing the training is then:
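A hedged sketch of such a preparation step: converting the rows to tensors and setting up a loss and an optimizer. The choice of mean squared error and Adam, the learning rate and the small stand-in networks are assumptions, not confirmed by the article.

```python
import numpy as np
import torch
import torch.nn as nn

rng = np.random.default_rng(0)

# Toy stand-in for the generated training sample
trueVariances = rng.uniform(0.5, 1.5, size=100)
points = [rng.normal(0.0, np.sqrt(v), size=rng.integers(10, 21))
          for v in trueVariances]

# Rows have different lengths, so they stay a list of 1D tensors
trainRows = [torch.tensor(p, dtype=torch.float32) for p in points]
trainTargets = torch.tensor(trueVariances, dtype=torch.float32)

# Small stand-ins for the input side (g) and output side (f)
g = nn.Sequential(nn.Linear(1, 16), nn.ReLU())
f = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 1))

lossFunc = nn.MSELoss()        # assumed loss: mean squared error
optimizer = torch.optim.Adam(  # assumed optimizer and learning rate
    list(g.parameters()) + list(f.parameters()), lr=1e-3)
```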

and the main training loop is:
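A compact, self-contained sketch of a training loop of this kind on a toy sample. The minibatch size, number of epochs, optimizer and network sizes are assumptions chosen to keep the example fast; they are not the settings behind the results reported in this article.

```python
import numpy as np
import torch
import torch.nn as nn

torch.manual_seed(0)
rng = np.random.default_rng(0)

# Toy sample: rows of 10 to 20 zero-centered points, target = true variance
trueVariances = rng.uniform(0.5, 1.5, size=200)
rows = [torch.tensor(rng.normal(0.0, np.sqrt(v), size=rng.integers(10, 21)),
                     dtype=torch.float32) for v in trueVariances]
targets = torch.tensor(trueVariances, dtype=torch.float32)

g = nn.Sequential(nn.Linear(1, 16), nn.ReLU(), nn.Linear(16, 16), nn.ReLU())
f = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 1))

def predict(batchRows):
    # mean-pool the outputs of g over the points of each row, then apply f
    pooled = torch.stack([g(r.unsqueeze(1)).mean(0) for r in batchRows])
    return f(pooled).squeeze(1)

lossFunc = nn.MSELoss()
optimizer = torch.optim.Adam(list(g.parameters()) + list(f.parameters()), lr=1e-3)

numEpochs, batchSize = 5, 32
epochLosses = []
for epoch in range(numEpochs):
    order = torch.randperm(len(rows)).tolist()   # reshuffle rows every epoch
    batchLosses = []
    for start in range(0, len(rows), batchSize):
        idx = order[start:start + batchSize]
        optimizer.zero_grad()
        loss = lossFunc(predict([rows[i] for i in idx]), targets[idx])
        loss.backward()
        optimizer.step()
        batchLosses.append(loss.item())
    epochLosses.append(float(np.mean(batchLosses)))
    print(f"epoch {epoch}: mean training loss {epochLosses[-1]:.4f}")
```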

Training takes a while on my machine. The evolution of the losses vs. training epoch looks as follows:

We can then evaluate our trained model on the test set:
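Evaluation typically means switching the networks to evaluation mode and running the forward pass without gradient tracking. A sketch under the same assumptions as before (stand-in networks g and f, left untrained here purely to keep the example short):

```python
import numpy as np
import torch
import torch.nn as nn

torch.manual_seed(0)
rng = np.random.default_rng(1)

# Toy test sample
testVariances = rng.uniform(0.5, 1.5, size=50)
testRows = [torch.tensor(rng.normal(0.0, np.sqrt(v), size=rng.integers(10, 21)),
                         dtype=torch.float32) for v in testVariances]

# Stand-ins for the trained networks (untrained in this sketch)
g = nn.Sequential(nn.Linear(1, 16), nn.ReLU())
f = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 1))

g.eval()
f.eval()
with torch.no_grad():   # no gradients are needed for evaluation
    pooled = torch.stack([g(r.unsqueeze(1)).mean(0) for r in testRows])
    predictedVariances = f(pooled).squeeze(1).numpy()
```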

Let’s look at the difference between the true variances used for generation of the Gaussian distributions and the predicted variances with the following code:
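One way to quantify such differences is the root mean squared error with respect to the true variances. The sketch below defines a hypothetical helper rmse and applies it to the two classical estimators on a toy sample; the network predictions would be passed to the same helper. No numbers from the article are reproduced here.

```python
import numpy as np

def rmse(estimates, truth):
    """Root mean squared error between estimates and true values."""
    estimates, truth = np.asarray(estimates), np.asarray(truth)
    return float(np.sqrt(np.mean((estimates - truth) ** 2)))

rng = np.random.default_rng(2)

# Toy sample generated the same way as in the sketches above
trueVariances = rng.uniform(0.5, 1.5, size=1000)
rows = [rng.normal(0.0, np.sqrt(v), size=rng.integers(10, 21))
        for v in trueVariances]

mlEstimates = [np.var(r, ddof=0) for r in rows]         # divide by n
unbiasedEstimates = [np.var(r, ddof=1) for r in rows]   # divide by n - 1

print("ML estimator RMSE:      ", rmse(mlEstimates, trueVariances))
print("unbiased estimator RMSE:", rmse(unbiasedEstimates, trueVariances))
```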

At the same time we compare the maximum likelihood and unbiased estimators of the variances to the true variances. This table summarizes the results for this example run:

The root mean squared error of the network is comparable to (if not slightly better than) that of the maximum likelihood and unbiased estimators. Keep in mind, however, that this holds only for this sample and not necessarily in general.

In conclusion, the network structure described above can be used to learn functions of a variable number of inputs.

(the underlying notebook for this article can be found here)