Keras shoot-out: TensorFlow vs MXNet

A few months ago, we took an early look at running Keras with Apache MXNet as its backend. Things were pretty beta at the time, but a lot of progress has since been made. It’s time to reevaluate… and benchmark MXNet against TensorFlow.

In this world, there’s two kinds of people, my friend. Those with GPUs and those who wait for days. You wait.

The story so far

The good folks at DMLC have forked Keras 1.2 in order to implement MXNet support, multi-GPU included. In parallel, they’ve moved the project to the Apache Incubator and are currently putting the finishing touches to MXNet 0.11. This is pretty impressive work in such a short time frame!

In addition to the Keras and MXNet codebases, here’s what we’re going to use today:

Let’s ride.

Installing MXNet and Keras

Once the instance is running, we first have to update MXNet to the latest version (0.11.0-rc3 at the time of writing). Here, we’re obviously going for GPU support.

Updating Keras is quite simple too.

Let’s check that we have the correct versions.
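One way to run this check from Python is sketched below. It relies on both packages exposing a `__version__` attribute; the `installed_version` helper is just a name made up for this example.

```python
import importlib


def installed_version(package_name):
    """Return the package's __version__ string, or None if it cannot be imported."""
    try:
        module = importlib.import_module(package_name)
    except ImportError:
        return None
    # Not every module defines __version__, so fall back to a placeholder.
    return getattr(module, "__version__", "unknown")


# Print whatever versions are installed on this machine.
for name in ("mxnet", "keras"):
    print(name, installed_version(name))
```

A result of `None` means the package could not be imported at all, which usually points to an installation problem rather than a version mismatch.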

Ok, looks good. Let’s move on to training.

Keras backends

Keras supports multiple backends for training and it’s very easy to switch from one to the other. Here are the two file versions for TensorFlow and MXNet.

All it takes is one line in the ~/.keras/keras.json file.
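If you switch back and forth a lot, the edit can be scripted. Here’s a minimal sketch that rewrites the `"backend"` entry and leaves every other setting untouched. The `set_keras_backend` helper is hypothetical, and the `"mxnet"` value assumes the forked Keras registers its backend under that name.

```python
import json
import os


def set_keras_backend(backend, config_path="~/.keras/keras.json"):
    """Set the "backend" entry of a Keras config file, keeping other settings."""
    path = os.path.expanduser(config_path)
    config = {}
    if os.path.exists(path):
        with open(path) as f:
            config = json.load(f)
    config["backend"] = backend
    # Create the directory if this is the first time the file is written.
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "w") as f:
        json.dump(config, f, indent=4)
    return config


# Example (not run here, since it would modify your real config file):
# set_keras_backend("mxnet")
# set_keras_backend("tensorflow")
```

Keras reads this file when it is imported, so the change only applies to Python processes started afterwards.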

Learning CIFAR-10 with TensorFlow

Keras provides plenty of nice examples in ~/keras/examples. We can use cifar10_resnet50.py pretty much as is. Since we’re going to be using all 8 GPUs, let’s just update the batch size to 256, the number of epochs to 100 and disable data augmentation.
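To see what a batch size of 256 means on 8 GPUs, here’s a quick back-of-the-envelope calculation. It assumes the batch is split evenly across the GPUs (plain data parallelism) and uses the 50,000 images of the CIFAR-10 training set.

```python
import math

TRAIN_IMAGES = 50000   # size of the CIFAR-10 training set
BATCH_SIZE = 256       # global batch size used in this post
NUM_GPUS = 8

# Samples each GPU processes per step, assuming an even split.
per_gpu_batch = BATCH_SIZE // NUM_GPUS
# Number of batches needed to see the whole training set once.
steps_per_epoch = math.ceil(TRAIN_IMAGES / BATCH_SIZE)

print(per_gpu_batch)     # 32
print(steps_per_epoch)   # 196
```

So each GPU only handles 32 samples per step, which is a fairly light load for this class of hardware.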

Time to train.

Here’s what memory usage looks like, as reported by nvidia-smi.

As we can see, TensorFlow is a bit of a memory hog, pretty much eating up 100% of available GPU memory. Not really a problem here, but I’m wondering if a much more complex model would still be able to fit in memory. To be tested in a future post, I suppose :)
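If you’d rather collect these numbers from a script than read the nvidia-smi table, here’s a minimal sketch using the tool’s CSV query mode. The helper names are made up for this example, and `gpu_memory_usage()` only works on a machine with the NVIDIA driver installed and nvidia-smi on the PATH.

```python
import subprocess

# Ask nvidia-smi for one CSV line per GPU: index, used MiB, total MiB.
QUERY = ["nvidia-smi",
         "--query-gpu=index,memory.used,memory.total",
         "--format=csv,noheader,nounits"]


def parse_gpu_memory(csv_text):
    """Parse nvidia-smi CSV lines into (gpu_index, used_mib, total_mib) tuples."""
    rows = []
    for line in csv_text.strip().splitlines():
        index, used, total = (int(field) for field in line.split(","))
        rows.append((index, used, total))
    return rows


def gpu_memory_usage():
    """Query the local GPUs and return the parsed memory figures."""
    output = subprocess.check_output(QUERY, universal_newlines=True)
    return parse_gpu_memory(output)


def percent_used(used_mib, total_mib):
    """Return memory usage as a percentage of the GPU's total memory."""
    return 100.0 * used_mib / total_mib
```

Calling `gpu_memory_usage()` while training is running gives one tuple per GPU, which makes it easy to compare backends side by side.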

After a while, here’s the result (full log here).

All right. Now let’s move on to MXNet.

Learning CIFAR-10 with MXNet

At the moment, auto-detection of GPUs is not implemented for MXNet in Keras, so we need to pass the list of available GPUs to the compile() API.

Just replace the call to model.compile() in cifar10_resnet.py with this snippet.
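As a rough sketch of the idea, the device list can be built like this. The `gpu_context_list` helper is a made-up name. The `context` keyword in the commented call is my assumption about the forked Keras API, and the loss and optimizer are placeholders, so check them against the example script before copying.

```python
def gpu_context_list(num_gpus):
    """Build the list of device names, e.g. ['gpu(0)', 'gpu(1)'] for two GPUs."""
    return ["gpu(%d)" % i for i in range(num_gpus)]


gpu_list = gpu_context_list(8)
print(gpu_list)

# With the MXNet backend active, the list would then be passed to compile().
# Assumed keyword (`context`) and placeholder loss/optimizer, shown as a
# comment because it needs the forked Keras and a model to run:
# model.compile(loss="categorical_crossentropy", optimizer="adam",
#               metrics=["accuracy"], context=gpu_list)
```

Building the list from a single number makes it easy to rerun the same script on an instance with fewer GPUs.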

Time to train.

Holy moly! MXNet is about 2.4 times faster: 25 seconds per epoch instead of 61, which cuts epoch time by roughly 60%. Very nice. In the same time frame, this would definitely allow us to try more things, like different model architectures or different hyperparameters. Definitely an advantage when you’re experimenting.
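For the record, here’s the arithmetic behind that comparison, using only the two per-epoch timings above. The 100-epoch totals are simple extrapolations and ignore any per-run overhead.

```python
TF_SECONDS_PER_EPOCH = 61
MXNET_SECONDS_PER_EPOCH = 25
EPOCHS = 100

# How many times faster one epoch runs on MXNet.
speedup = TF_SECONDS_PER_EPOCH / MXNET_SECONDS_PER_EPOCH
# Fraction of the TensorFlow epoch time that MXNet saves.
time_saved = 1 - MXNET_SECONDS_PER_EPOCH / TF_SECONDS_PER_EPOCH

print("speedup: %.2fx" % speedup)                  # speedup: 2.44x
print("time saved: %.0f%%" % (100 * time_saved))   # time saved: 59%

# Extrapolated training time for the full run, in minutes.
mxnet_minutes = MXNET_SECONDS_PER_EPOCH * EPOCHS / 60
tf_minutes = TF_SECONDS_PER_EPOCH * EPOCHS / 60
print("MXNet: %.1f min" % mxnet_minutes)           # MXNet: 41.7 min
print("TensorFlow: %.1f min" % tf_minutes)         # TensorFlow: 101.7 min
```

In other words, a full run that takes over an hour and a half on one backend fits in well under an hour on the other.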

What about memory usage? As we can see, MXNet uses over 90% less GPU memory and there is plenty left for other jobs.

Here’s the result after 100 epochs (full log here): 43 minutes, 99.4% training accuracy, 62% test accuracy.

Conclusion

Granted, this is a single example and no hasty conclusion should be drawn. Still, with 8 GPUs and a well-known data set, MXNet is significantly faster, much more memory-efficient and more accurate than TensorFlow.

It seems to me every Deep Learning practitioner ought to check MXNet out, especially now that it’s properly integrated with Keras: changing a line of configuration is all it takes :)

If you’d like to dive a bit more into MXNet, may I recommend the following resources?

In part 2, I’m taking a deeper look at memory usage in TensorFlow and how to optimise it.

In part 3, we’ll learn how to fine-tune the models for improved accuracy.

Thank you for reading :)


This post was written while blasting classics by Whitesnake, Rainbow and Dio. Fortunately, no neighbour was injured in the process.