Revisiting Classification

Davis Liang
Published in TheLocalMinima
Jan 23, 2019 · 5 min read

Reconstructing classification as a regression task.

The Old Testament

In the context of deep learning, when tasked with performing classification, a practitioner may first take some input and subject it to many linear and nonlinear transformations. She may then take the result, whose dimensionality is necessarily the number of classes in the dataset, and apply the softmax function, which turns these scores into a probability distribution over classes (a smooth approximation of the arg max).

prim & proper.
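For concreteness, here is a minimal sketch of that standard pipeline, written in PyTorch with illustrative layer sizes (the exact architecture is not the point):

```python
import torch
import torch.nn as nn

# A toy classifier: a few linear and nonlinear transformations, ending in a
# layer whose output dimension equals the number of classes, then softmax.
num_classes = 10  # e.g. CIFAR-10

classifier = nn.Sequential(
    nn.Flatten(),
    nn.Linear(3 * 32 * 32, 512),
    nn.ReLU(),
    nn.Linear(512, num_classes),      # one logit per class
)

images = torch.randn(8, 3, 32, 32)    # a dummy batch of CIFAR-sized images
logits = classifier(images)           # shape: (8, num_classes)
probs = torch.softmax(logits, dim=1)  # smooth approximation of the arg max
predictions = probs.argmax(dim=1)
```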

This process results in a classifier with several drawbacks:

  1. The resulting neural net requires a fine-tuning step and additional labeled data to generalize to new classes.
  2. As the number of classes grows, so does the computational cost of training the model.

A Semantic Alternative

Imagine if we took a dataset (say CIFAR-10) and plotted the labels, as one-hot vectors, in N-dimensional vector space. Pretty boring, right? The labels for cat and dog are orthogonal (as are cat and plane); in other words, such a scheme offers no semantic information about the labels beyond the fact that they are different.

But what happens when we leverage word vectors and plot the skip-gram representations of these labels in semantic vector space? Looking at the first two principal components of these representations, we see obvious clusters begin to emerge: for CIFAR-10, the animals are all represented on one side while the vehicles are all represented on the other. Cats are closer to dogs than to trucks. From a sample of Caltech-256, we observe a cluster for plants, a cluster for musical instruments, and another cluster for electronic devices.

aha!
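If you want to reproduce a plot like this, a quick sketch looks something like the following; it assumes gensim's pretrained GloVe vectors as a convenient stand-in for the skip-gram embeddings described above, along with scikit-learn's PCA and matplotlib:

```python
import gensim.downloader as api
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

# Any pretrained word vectors will do; this particular gensim model name is
# just a small, convenient choice.
wv = api.load("glove-wiki-gigaword-100")

cifar10_labels = ["airplane", "automobile", "bird", "cat", "deer",
                  "dog", "frog", "horse", "ship", "truck"]

# Project the label embeddings onto their first two principal components.
coords = PCA(n_components=2).fit_transform([wv[w] for w in cifar10_labels])

plt.scatter(coords[:, 0], coords[:, 1])
for (x, y), label in zip(coords, cifar10_labels):
    plt.annotate(label, (x, y))
plt.title("CIFAR-10 label embeddings, first two principal components")
plt.show()
```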

From here, the picture starts to come together. If we can somehow use these semantic representations as the supervisory signal for our classifier, then we effectively solve both issues discussed above. Specifically,

  1. By training on N class categories (e.g. dog, horse, frog), we can introduce a new category (e.g. cat) without any additional labeled training examples. We assume that, with enough trained categories, our network will be able to ‘learn’ the semantic vector space. More on this later.
  2. Even as the number of classes grows, the dimensionality of the semantic vector space stays the same.

The only question remaining is: how do we construct a setup that allows us to train and evaluate against semantic representations?

The New Testament

The solution is simple: regression and nearest neighbors.

  • During training, we regress against the semantic label vectors using an L2 loss.
  • During inference, we keep a bank of label word vectors and predict the class whose vector is nearest to the network’s output; this nearest-neighbor lookup can be performed in a distributed fashion (sketched below).
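Here is a minimal sketch of both pieces, assuming a toy CNN backbone and a `predict` helper that does the nearest-neighbor lookup; the names and layer sizes are illustrative, not the exact setup from our experiments:

```python
import torch
import torch.nn as nn

class SemanticRegressor(nn.Module):
    """Maps images into word-vector space instead of onto class logits."""

    def __init__(self, embed_dim=300):
        super().__init__()
        self.backbone = nn.Sequential(        # stand-in feature extractor
            nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
        )
        self.head = nn.Linear(32, embed_dim)  # project into word-vector space

    def forward(self, x):
        return self.head(self.backbone(x))

def training_step(model, optimizer, images, target_vectors):
    # Regress against the semantic label vectors with an L2 (MSE) loss.
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(images), target_vectors)
    loss.backward()
    optimizer.step()
    return loss.item()

@torch.no_grad()
def predict(model, images, label_vectors):
    # Nearest neighbor in word-vector space: the prediction for each image is
    # the index of the label whose vector is closest to the regressed output.
    outputs = model(images)                      # (batch, embed_dim)
    dists = torch.cdist(outputs, label_vectors)  # (batch, num_labels)
    return dists.argmin(dim=1)
```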

In practice, this method of training and evaluating works quite well. We trained two small models: one using the classic multi-class classification approach and another using our semantic vector regression approach.

we do no worse!

Taking a look at the confusion matrices for CIFAR-10 under multi-class classification vs. semantic regression, we observe a similar trend: semantic regression generally does no worse (and in many cases better) than old-school multi-class classification.

One-hot seems a bit more confused.

Zero-Shot Learning

We know what happens during normal training and evaluation. But what happens when we introduce a new class category? For example, what happens when we train a network on CIFAR-10 and test it on images from CIFAR-100? For vanilla multi-class classification, we know that our network would be wholly unable to perform the task — it lacks the output node to predict the new label. However, a semantic regression approach that is able to bootstrap from an existing repository of trained embeddings should do better.
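Concretely, reusing the hypothetical `SemanticRegressor` and `predict` helper sketched earlier, adding a new category is just a matter of appending its word vector to the label bank; no new output node and no retraining are required:

```python
import torch

# `model`, `wv` (a pretrained word-vector lookup), and `test_images` are
# assumed to come from the earlier sketches; the class names are illustrative.
seen_classes = ["dog", "horse", "frog"]  # categories used during training
unseen_classes = ["cat"]                 # a brand-new category at test time
all_classes = seen_classes + unseen_classes

# Zero-shot setup: grow the bank of label vectors with the unseen classes.
label_vectors = torch.stack([torch.as_tensor(wv[c]) for c in all_classes])

preds = predict(model, test_images, label_vectors)  # indices into all_classes
predicted_names = [all_classes[i] for i in preds]
```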

We tested this approach by performing inference on never-before-seen categories drawn from Caltech-256, using a model trained on Caltech-101. Here, we randomly sample 4 of the new classes to evaluate zero-shot learning.

Boom. It generalizes!

Interestingly enough, the results revealed a high generalization rate, with over half of the added classes achieving accuracy greater than random guessing. With a small fine-tuning step on a few images, this accuracy increases drastically without hurting performance on the existing categories.

Nothing’s Perfect

But nothing’s perfect, right? Unsurprisingly, word vectors are not the best representation for images. A picture of a zebra and a picture of a horse, while close in image space, are drastically different in word-vector space (we generally use horse and zebra in very different natural language contexts, e.g. horses are to be ridden and zebras are to… well… roam free on the African savanna). Perhaps the next step is to redo this experiment with proper semantic image representations (hidden-layer representations, anyone?). Additionally, semantic regression requires the semantic vectors to be learned beforehand.

With this said, the benefits of this approach are still quite clear.

Conclusion

From our casual experiments, it seems that semantic regression not only improves overall accuracy, but also confers the ability to perform zero-shot learning. Furthermore, semantic regression allows us to perform evaluation in a distributed fashion, removing the softmax bottleneck. Though not perfect, this approach is a clear alternative to vanilla multi-class classification. Thanks to my former graduate school colleagues Arvind Rao, Daniel Riley, Cuong Luong, and Gannon Gesiriech for helping run experiments and providing useful discussion.
