Semi-supervised learning from side information

Consider a learning task where generating labels is prohibitively expensive, but where it is possible to gather a number of auxiliary signals that are informative about the true label. These auxiliary signals are not available at test time.

That is, there is a sequence of increasingly expensive types of labels, z¹, z², …, zⁿ = y. Write x = z⁰ for the input features.

Many or all data points are labelled with z¹, fewer are labelled with both z¹ and z², and so on, down to a very small number labelled with every type of label (including y). Our goal is to predict y given x.
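To make the setup concrete, here is one way the data might look (a hypothetical layout, not prescribed by anything above): each record carries a prefix of the signal sequence, and deeper signals are rarer.

```python
# Hypothetical data layout: each record carries a prefix of the signals
# z1, z2, ..., with None marking signals we never collected. Cheap signals
# are plentiful; the full label y is rare.
dataset = [
    {"x": "...", "z1": 0.7, "z2": None, "y": None},  # only the cheapest signal
    {"x": "...", "z1": 0.9, "z2": 0.8,  "y": None},  # two levels of signal
    {"x": "...", "z1": 0.2, "z2": 0.1,  "y": 0.0},   # fully labelled (rare)
]
```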

I think that this setting is quite common, and that it should be of broad interest to AI researchers. But I’m not aware of much work in this setting, and in particular I’ve never seen the algorithm in this post despite its simplicity. (I welcome pointers to the existing literature that I’ve missed!)

This problem is especially relevant to AI control, since it lets us make use of expensive human oversight that is administered only with low probability. For example, good solutions play an important role in my view that we can probably build zero-overhead aligned AI systems.

Examples

  1. x is a voice search query. y is an accurate human transcription of the query. z¹ is the text query the user made immediately after the voice query, if any, or “null” otherwise. z² is a transcription produced by a mechanical turker without strong accuracy checks.
  2. x is a question. The zⁱ are a sequence of answers to that question generated after increasingly lengthy reflection.
  3. x is the observation/state of a virtual agent, and a proposed action in that state. y is a human’s evaluation of how good that action is. z¹ is the cumulative reward obtained after taking that action, where the reward function is a proxy for human approval.
  4. x is the state of a Go game. y is the outcome of playing out a full game from that state. z¹ is the result of rollouts with a very low cost policy. z² is the outcome of a game played by an intermediate-quality policy network starting from the given state. z³ is a human judgment about the quality of a board position.

Assumptions/elaborations

We assume that whether a particular signal is collected is unrelated to the underlying label (i.e., signals are missing at random). Signals might take a “null” value, but a signal being present with a null value is different from it being absent because we didn’t bother to collect it.

I’m going to consider the case where the zⁱ are sparse, but we could also ask a more general question where we are simply interested in variance reduction. I think the more general question is extremely interesting, and that similar techniques will be applicable.

A simple algorithm

Consider the following algorithm:

  • Let fⁿ = zⁿ.
  • For i = n−1 down to 0, train a predictor fⁱ to predict fⁱ⁺¹ given z⁰, z¹, …, zⁱ. Training uses only the data points where all of the relevant signals are available. The loss is the cross-entropy −Σ fⁱ⁺¹(y) log( fⁱ(y) ), summed over possible labels y.
  • Use f⁰ as the classifier.

This is essentially TD learning without the time, and with a logarithmic rather than quadratic loss.

If we are predicting a scalar, then we can go back to using a quadratic loss.
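As a concrete sketch of that scalar, quadratic-loss version (the linear least-squares predictors, the row format, and the function names are my own illustrative choices, not part of the proposal):

```python
import numpy as np

def fit(X, t):
    # Ridge-regularized least squares; the regularizer is arbitrary.
    return np.linalg.solve(X.T @ X + 1e-3 * np.eye(X.shape[1]), X.T @ t)

def fit_chain(rows, n):
    """Train f^(n-1), ..., f^0, where f^i predicts f^(i+1) from (z^0, ..., z^i).

    rows: dicts with keys "z0", ..., "z{n}" (None when a signal was not
    collected), with z^n = y. Linear predictors stand in for whatever model
    class would actually be used; f^n is the label itself. Each f^i is
    trained only on rows where z^(i+1) is available, as in the bullets above.
    """
    w = [None] * n
    for i in range(n - 1, -1, -1):                    # train f^(n-1) first
        rs = [r for r in rows if r[f"z{i+1}"] is not None]
        X = np.array([[r[f"z{j}"] for j in range(i + 1)] for r in rs])
        if i == n - 1:
            t = np.array([r[f"z{n}"] for r in rs])    # target is y itself
        else:                                         # target is f^(i+1)'s output
            t = np.array([[r[f"z{j}"] for j in range(i + 2)] for r in rs]) @ w[i + 1]
        w[i] = fit(X, t)
    return w                                          # w[0] defines the classifier
```

At test time only w[0] is used, applied to x = z⁰ alone, since the auxiliary signals are unavailable.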

If the space of possible predictions is very large, such that we are essentially training a generative model, then we could instead use a generative adversarial net at each step (with fⁱ trying to sample from the distribution produced by fⁱ⁺¹). I won’t discuss this version further in this post, since it involves some extra complications.

In contrast with a natural generative method for leveraging the zⁱ, this approach does not bother trying to reproduce the intermediate labels themselves. We will see shortly that this lets us argue that having the intermediate labels can only help.

An improvement

If each predictor in the chain were more powerful than the one before it, then using this procedure directly would seem sensible. But each successive predictor fⁱ⁺¹ is actually trained on much less data than fⁱ, and so must have lower capacity or be more aggressively regularized.

That means that we can’t count on (e.g.) f³ being strictly smarter than f². So f² might actually predict better by predicting the outcome directly rather than predicting f³.

We can get the best of both worlds by using f³ as a variance reduction strategy for f².

That is, we train fⁱ on the sum of the following targets:

  • fⁱ⁺¹ for every data point where zⁱ⁺¹ is available.
  • (fⁱ⁺² − fⁱ⁺¹)/ε for data points where zⁱ⁺² is available, where ε is the probability that zⁱ⁺² is available conditioned on zⁱ⁺¹ being available.
  • (fⁱ⁺³ − fⁱ⁺²)/ε′ for data points where zⁱ⁺³ is available, where ε′ is the probability that zⁱ⁺³ is available conditioned on zⁱ⁺¹ being available.
  • And so on.

In the case where fⁱ⁺¹ is completely uninformative, this reduces to only training fⁱ on data points where zⁱ⁺² is available. In the case where fⁱ⁺¹ = y, this amounts to training fⁱ on all of the data. (If the predictor fⁱ is constant, this essentially reduces to output distribution matching.)
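Here is one way the summed target might be assembled on a single data point (the helper names and the eps bookkeeping are mine; in practice each eps[j] would be estimated from the data, e.g. as an empirical frequency):

```python
def available(row, j):
    return row.get(f"z{j}") is not None

def target_for(row, i, f, n, eps):
    """Variance-reduced training target for f^i on a single data point.

    f[j] evaluates predictor f^j on the row (with f[n] returning the label
    z^n itself); eps[j] is the estimated probability that z^j is available
    given that z^(i+1) is. Returns None when the point carries no training
    signal for f^i.
    """
    if not available(row, i + 1):
        return None                      # f^i only trains where z^(i+1) exists
    total = f[i + 1](row)                # first bullet: f^(i+1) itself
    for j in range(i + 2, n + 1):        # later bullets: weighted residuals
        if not available(row, j):
            break                        # linear order: deeper signals also missing
        total += (f[j](row) - f[j - 1](row)) / eps[j]
    return total
```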

A claim

Including the intermediate signals zⁱ can only improve on a naive prediction algorithm that uses only the fully labelled data.

To see this, we use the fact that the naive prediction algorithm is equivalent (modulo the choice of learning rates) to our improved algorithm with the helper predictors fⁱ set to 0 for all i ≥ 1.

But improving the quality of fⁱ as a predictor is guaranteed to reduce the variance of our training signal, while any change to fⁱ leaves the signal unbiased (at least if we assume that fⁱ is unbiased and that it does not have higher variance on data points where zⁱ⁺¹ is unavailable). This is because the variance of the target decomposes as the variance of an unbiased predictor plus that predictor’s mean squared error, i.e. Var(target) = Var(fⁱ) + E[(target − fⁱ)²]; only the residual term is inflated by the 1/ε importance weights, so shifting more of the variance into the predictor decreases the total variance of the signal.

So using any reasonable training procedure for the fⁱ will result in a lower variance signal than simply setting them to 0.
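A toy Monte Carlo check of this argument, with a single intermediate signal and made-up numbers (the 0.9 coefficient is an arbitrary stand-in for a learned f¹, not the optimal predictor):

```python
import numpy as np

rng = np.random.default_rng(0)
N, eps = 200_000, 0.05                  # eps = P(y available | z1 available)

y = 1.0 + rng.normal(size=N)            # expensive label (rarely collected)
z1 = y + 0.3 * rng.normal(size=N)       # cheap signal: noisy view of y
has_y = rng.random(N) < eps             # which points carry the full label

def signal(f1):
    """Training target for f^0: f1 everywhere, plus (y - f1)/eps where y exists."""
    return f1 + has_y * (y - f1) / eps

naive = signal(np.zeros(N))             # f1 = 0 recovers the naive algorithm
better = signal(0.9 * z1)               # a decent (not even optimal) f1

print(naive.mean(), better.mean())      # both ≈ E[y] = 1: signal stays unbiased
print(naive.var(), better.var())        # variance drops by roughly an order
                                        # of magnitude with the better f1
```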

Similarly, adding additional intermediate training signals can only help.

It’s not clear how much these additions help. Intuitively, there are many cases where they would significantly reduce the variance of the training signal. Reducing the variance of a training signal by a factor of k corresponds roughly to increasing the amount of data by a factor of k. But it’s hard to know how it will actually play out without trying it in some context.

More complex structures

I’ve assumed that zⁱ is available whenever zⁱ⁺¹ is. It would be nice to apply similar methods in cases where we don’t have such a simple linear order.

In this setting, for each set of indices T ⊆ {0, 1, 2, …, n}, we train a predictor fᵀ to predict y given the values zⁱ for i ∈ T. As before, we can use the predictors corresponding to big sets to do variance reduction for the predictors corresponding to small sets. Our classifier is of course the predictor for the subset T = {0}.
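For concreteness, the bookkeeping for the subset-indexed predictors might look like this sketch (make_predictor and the value of n are placeholders):

```python
from itertools import chain, combinations

n = 3  # number of auxiliary signal types; purely illustrative

# Every index set T ⊆ {0, 1, ..., n}; f[T] predicts y from {z^i : i in T}.
idx = range(n + 1)
all_T = [frozenset(T) for T in
         chain.from_iterable(combinations(idx, k) for k in range(n + 2))]

def make_predictor(T):
    return None  # placeholder for whatever model family is actually used

# Predictors for supersets of T can act as variance-reduction baselines
# when training f[T]; at test time we only ever query f[{0}].
f = {T: make_predictor(T) for T in all_T}
classifier = f[frozenset({0})]
```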

However, there are now many possible variance reduction schemes, each of which is quite complicated. I don’t have any particular insight into how to choose amongst them, or even into whether they will work at all. This seems like an interesting question to revisit if applications with complex structure become important, but it’s probably fine to set it aside for now.