PyTorch vs TensorFlow — spotting the difference
In this post I want to explore some of the key similarities and differences between two popular deep learning frameworks: PyTorch and TensorFlow. Why those two and not the others? There are many deep learning frameworks, and many of them are viable tools; I chose these two simply because I was interested in comparing them specifically.
TensorFlow is developed by Google Brain and actively used at Google for both research and production. Its closed-source predecessor is called DistBelief.
PyTorch is a cousin of the Lua-based Torch framework, which is actively used at Facebook. However, PyTorch is not simply a set of wrappers that bring Torch to a popular language; it was rewritten and tailored to be fast and to feel native.
The best way to compare two frameworks is to code something up in both of them. I've written a companion Jupyter notebook for this post, and you can get it here. All the code will be provided in the post too.
First, let’s code a simple approximator for the following function in both frameworks:
We will try to find the unknown parameter phi given data x and function values f(x). Yes, using stochastic gradient descent for this is overkill, and an analytical solution is easy to find, but this problem will serve our purpose well as a simple example.
We will solve this with PyTorch first:
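Assuming the function has the form f(x) = x^phi (the text refers to phi as an exponent), a minimal sketch of this approach in current PyTorch might look like the following. The true value phi = 3.0, the input range [1, 2), the learning rate and the iteration count are illustrative choices, not values from the companion notebook. Autograd computes the gradient; the update step itself is written by hand.

```python
import torch

torch.manual_seed(0)

# Assumption: illustrative target f(x) = x ** phi with true phi = 3.0.
TRUE_PHI = 3.0
x = 1.0 + torch.rand(1000)   # inputs in [1, 2)
y = x ** TRUE_PHI            # noiseless targets f(x)

# The unknown parameter; requires_grad tells autograd to track it.
phi = torch.tensor(1.0, requires_grad=True)
learning_rate = 0.02

for step in range(1000):
    loss = ((x ** phi - y) ** 2).mean()   # mean squared error
    loss.backward()                       # fills phi.grad
    with torch.no_grad():                 # update outside the autograd graph
        phi -= learning_rate * phi.grad
    phi.grad.zero_()                      # gradients accumulate unless cleared

print(f"estimated phi: {phi.item():.4f}")
```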
If you have some experience with deep learning frameworks, you may have noticed that we are implementing gradient descent by hand. Not very convenient, huh? Luckily, PyTorch has the torch.optim module, which contains implementations of popular optimization algorithms such as RMSProp and Adam. We will use SGD with momentum.
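With an optimizer from torch.optim, the hand-written update disappears. Here is a sketch under the same kind of assumptions (target f(x) = x^phi, an illustrative true phi = 3.0, and hyperparameters picked for this toy problem rather than taken from the notebook):

```python
import torch

torch.manual_seed(0)

# Assumption: illustrative target f(x) = x ** phi with true phi = 3.0.
TRUE_PHI = 3.0
x = 1.0 + torch.rand(1000)   # inputs in [1, 2)
y = x ** TRUE_PHI

phi = torch.tensor(1.0, requires_grad=True)
# The optimizer receives the tensors to update plus its hyperparameters.
optimizer = torch.optim.SGD([phi], lr=0.01, momentum=0.9)

for step in range(2000):
    optimizer.zero_grad()                  # reset gradients from the last step
    loss = ((x ** phi - y) ** 2).mean()    # mean squared error
    loss.backward()                        # compute d(loss)/d(phi)
    optimizer.step()                       # momentum update of phi

print(f"estimated phi: {phi.item():.4f}")
```

The zero_grad / backward / step triple is the usual PyTorch training step; switching to torch.optim.Adam or torch.optim.RMSprop only changes the line that constructs the optimizer.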
As you can see, we quickly inferred the true exponent from the training data. Now let's move on to TensorFlow:
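For comparison, here is a sketch of the TensorFlow side in the graph-and-session style discussed in this post. It assumes the same illustrative target f(x) = x^phi with phi = 3.0, and it uses the TensorFlow 1.x API through tf.compat.v1 so that it also runs on TensorFlow 2.x; the optimizer settings are chosen for this toy problem.

```python
import numpy as np
import tensorflow as tf

# Assumption: illustrative target f(x) = x ** phi with true phi = 3.0.
TRUE_PHI = 3.0
rng = np.random.default_rng(0)
x_data = 1.0 + rng.random(1000).astype(np.float32)   # inputs in [1, 2)
y_data = x_data ** TRUE_PHI

graph = tf.Graph()
with graph.as_default():
    # Placeholders stand in for data that is fed in at run time.
    x = tf.compat.v1.placeholder(tf.float32, shape=[None], name="x")
    y = tf.compat.v1.placeholder(tf.float32, shape=[None], name="y")
    phi = tf.compat.v1.Variable(1.0, name="phi")
    loss = tf.reduce_mean(tf.square(tf.pow(x, phi) - y))
    train_op = tf.compat.v1.train.MomentumOptimizer(
        learning_rate=0.01, momentum=0.9).minimize(loss)
    init = tf.compat.v1.global_variables_initializer()

# Building the graph computed nothing; the session runs it with real data.
with tf.compat.v1.Session(graph=graph) as sess:
    sess.run(init)
    for step in range(2000):
        sess.run(train_op, feed_dict={x: x_data, y: y_data})
    estimated_phi = sess.run(phi)

print(f"estimated phi: {estimated_phi:.4f}")
```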
As you can see, the TensorFlow implementation works too (surprisingly 🙃). It took more iterations to recover the exponent, but I am fairly sure the cause is that I did not tune the optimizer's parameters enough to reach comparable results.
Now we are ready to explore some differences.
Difference #0 — adoption
Currently, TensorFlow is considered a go-to tool by many researchers and industry professionals. The framework is well documented, and if the documentation does not suffice, there are many extremely well-written tutorials on the internet. You can find hundreds of implemented and trained models on GitHub; start here.
PyTorch is relatively new compared to its competitor (and is still in beta), but it is quickly gaining momentum. Its documentation and official tutorials are also nice. PyTorch also includes several implementations of popular computer vision architectures, which are super easy to use.
Difference #1 — dynamic vs static graph definition
Both frameworks operate on tensors and view any model as a directed acyclic graph (DAG), but they differ drastically in how you can define these graphs.
TensorFlow follows the ‘data as code and code is data’ idiom. In TensorFlow you define the graph statically before a model can run. All communication with the outside world is performed via a tf.Session object and tf.placeholder tensors, which are substituted with external data at runtime.
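A tiny sketch of that split between building a graph and running it. tf.placeholder and tf.Session are TensorFlow 1.x APIs; on TensorFlow 2.x they are available under tf.compat.v1 and only work in graph mode, which is why the graph is built inside an explicit tf.Graph context here. The names a, b and total are illustrative.

```python
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    # Graph construction: these calls only add nodes; no arithmetic happens yet.
    a = tf.compat.v1.placeholder(tf.float32, shape=[], name="a")
    b = tf.compat.v1.placeholder(tf.float32, shape=[], name="b")
    total = a * b + 1.0

print(total)  # a symbolic tf.Tensor, not a number

with tf.compat.v1.Session(graph=graph) as sess:
    # Execution: the session substitutes external data for the placeholders.
    first = sess.run(total, feed_dict={a: 2.0, b: 3.0})   # 7.0
    second = sess.run(total, feed_dict={a: 4.0, b: 0.5})  # 3.0
```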
In PyTorch things are far more imperative and dynamic: you can define, change and execute nodes as you go, with no special session interfaces or placeholders. Overall, the framework is more tightly integrated with the Python language and feels more native most of the time. When you write in TensorFlow, you sometimes feel that your model is behind a brick wall with several tiny holes to communicate through. Anyway, this is still more or less a matter of taste.
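To illustrate the define-by-run style, the sketch below puts an ordinary Python loop with a data-dependent break inside the forward computation. The layer size and the early-exit threshold are arbitrary, hypothetical values; the point is that each call may record a graph of a different depth and backward() still works.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(4, 4)

def forward(x: torch.Tensor, max_steps: int) -> torch.Tensor:
    # Plain Python control flow: the graph is recorded while this loop runs,
    # so every call can build a graph of a different depth.
    for _ in range(max_steps):
        x = torch.tanh(layer(x))
        if x.abs().max() < 1e-3:   # data-dependent early exit
            break
    return x.sum()

x = torch.randn(1, 4)
grad_norms = []
for depth in (1, 3, 5):
    layer.zero_grad()
    out = forward(x, depth)
    out.backward()   # autograd walks back through whichever graph was just built
    grad_norms.append(layer.weight.grad.norm().item())

print(grad_norms)
```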
However, these approaches differ not only from a software engineering perspective: several dynamic neural network architectures can benefit from the dynamic approach. Recall RNNs: with static graphs, the input sequence length will stay constant. This means that if you develop a sentiment analysis model for English sentences, you must fix the sentence length to some maximum value and pad all shorter sequences with zeros. Not very convenient, huh? And you will run into more problems in the domain of recursive RNNs and tree-RNNs. Currently TensorFlow has limited support for dynamic inputs via TensorFlow Fold. PyTorch supports them by default.
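As a sketch of what this means for RNNs in PyTorch: the same nn.LSTM can consume sequences of different lengths one at a time, or a whole batch of them via pack_sequence, which records each sequence's length so that no computation is spent on padding. The random tensors stand in for embedded sentences; all sizes are hypothetical.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_sequence

torch.manual_seed(0)
EMBED_DIM, HIDDEN_DIM = 8, 16
lstm = nn.LSTM(input_size=EMBED_DIM, hidden_size=HIDDEN_DIM, batch_first=True)

# Hypothetical "sentences" of different lengths, already embedded.
sentences = [torch.randn(length, EMBED_DIM) for length in (5, 3, 9)]

# Option 1: run each sequence at its natural length, no padding at all.
final_states = []
for seq in sentences:
    _, (h_n, _) = lstm(seq.unsqueeze(0))   # input shape (1, length, EMBED_DIM)
    final_states.append(h_n[-1])           # last layer's hidden state, (1, HIDDEN_DIM)

# Option 2: batch them; the packed sequence tracks lengths, so padding is skipped.
packed = pack_sequence(sentences, enforce_sorted=False)
_, (h_batch, _) = lstm(packed)             # h_batch shape (1, 3, HIDDEN_DIM)
```

With enforce_sorted=False, the returned hidden states come back in the original order of the sentences, so both options should agree up to floating-point noise.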
Difference #2 — Debugging
Since the computation graph in PyTorch is defined at runtime, you can use your favorite Python debugging tools such as pdb, ipdb, the PyCharm debugger, or good old print statements.
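For example, in the hypothetical module below, the intermediate activations inside forward() are ordinary tensors, so they can be printed, or inspected in pdb, like any other Python object:

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(3, 5)
        self.out = nn.Linear(5, 1)

    def forward(self, x):
        h = torch.relu(self.hidden(x))
        # h is a concrete tensor right here; uncomment the next line to stop
        # in the standard Python debugger at this point of the forward pass.
        # import pdb; pdb.set_trace()
        print("hidden:", tuple(h.shape), "fraction of zeros:", (h == 0).float().mean().item())
        return self.out(h)

net = TinyNet()
y = net(torch.randn(4, 3))
```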
This is not the case with TensorFlow. You can use a special tool called tfdbg, which allows you to evaluate TensorFlow expressions at runtime and browse all tensors and operations in the session scope. Of course, you won’t be able to debug any Python code with it, so you will need to use pdb separately.
Difference #3 — Visualization
TensorBoard is awesome when it comes to visualization 😎. This tool comes with TensorFlow, and it is very useful for debugging and for comparing different training runs. For example, suppose you trained a model, then tuned some hyperparameters and trained it again. Both runs can be displayed in TensorBoard simultaneously to highlight possible differences. TensorBoard can:
- Display model graph
- Plot scalar variables
- Visualize distributions and histograms
- Visualize images
- Visualize embeddings
- Play audio
TensorBoard can display various summaries, which are collected via the tf.summary module. We will define summary operations for our toy exponent example and use tf.summary.FileWriter to save them to disk.
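tf.summary.FileWriter belongs to the TensorFlow 1.x API; in TensorFlow 2.x the closest equivalent is tf.summary.create_file_writer. Here is a sketch of logging the toy exponent fit with that newer API, assuming f(x) = x^phi with an illustrative true phi = 3.0, plain gradient descent, and a temporary log directory (point it at ./tensorboard instead to match the launch command).

```python
import os
import tempfile

import tensorflow as tf

# Assumptions: TensorFlow 2.x and an illustrative target f(x) = x ** phi, true phi = 3.0.
TRUE_PHI = 3.0
x = 1.0 + tf.random.uniform([1000], seed=0)   # inputs in [1, 2)
y = x ** TRUE_PHI
phi = tf.Variable(1.0, name="phi")
learning_rate = 0.02

# Temporary directory for this sketch; use "./tensorboard" to match the launch command.
logdir = tempfile.mkdtemp()
writer = tf.summary.create_file_writer(logdir)

with writer.as_default():
    for step in range(200):
        with tf.GradientTape() as tape:
            residuals = x ** phi - y
            loss = tf.reduce_mean(tf.square(residuals))
        grad = tape.gradient(loss, phi)
        phi.assign_sub(learning_rate * grad)   # plain gradient descent step
        # Each call adds one data point for this step; TensorBoard plots them over time.
        tf.summary.scalar("loss", loss, step=step)
        tf.summary.scalar("phi", phi, step=step)
        tf.summary.histogram("residuals", residuals, step=step)
writer.flush()

print("event files:", os.listdir(logdir))
```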
To launch TensorBoard, run tensorboard --logdir=./tensorboard. This tool is very convenient to use on cloud instances since it is a web app.
Difference #4 — Deployment
When it comes to deployment, TensorFlow is the clear winner for now: it has TensorFlow Serving, a framework for deploying your models on a specialized gRPC server. Mobile is also supported.
With PyTorch, we can use Flask or another alternative to code up a REST API on top of the model. This could be done with TensorFlow models as well if gRPC is not a good match for your use case. However, TensorFlow Serving may be the better option if performance is a concern.
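A minimal sketch of such a REST API with Flask. The single-layer model, the /predict route and the JSON format are hypothetical placeholders for a real trained model and request schema.

```python
import torch
import torch.nn as nn
from flask import Flask, jsonify, request

# Hypothetical model; a real service would load trained weights here,
# e.g. model.load_state_dict(torch.load("model.pt")).
model = nn.Linear(3, 1)
model.eval()

app = Flask(__name__)

@app.route("/predict", methods=["POST"])
def predict():
    payload = request.get_json()   # expects {"inputs": [[f1, f2, f3], ...]}
    inputs = torch.tensor(payload["inputs"], dtype=torch.float32)
    with torch.no_grad():          # inference only, no autograd graph
        outputs = model(inputs)
    return jsonify({"outputs": outputs.squeeze(1).tolist()})

# To serve locally: app.run(port=5000)
```

app.run() starts Flask's development server, which is fine for experiments; for production traffic a dedicated WSGI server is the usual choice.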
TensorFlow also supports distributed training, which PyTorch lacks for now.
Difference #5 — A Framework or a library
Let’s build a CNN classifier for handwritten digits. Now PyTorch will really start to look like a framework. Recall that a programming framework gives us useful abstractions in a certain domain and a convenient way to use them to solve concrete problems. That is the essence that separates a framework from a library.
Here we introduce the datasets module (from torchvision), which contains wrappers for popular datasets used to benchmark deep learning architectures. We also use nn.Module to build a custom convolutional neural network classifier.
nn.Module is the building block PyTorch gives us for creating complex deep learning architectures. The torch.nn package contains a large number of ready-to-use modules that we can use as a base for our model. Notice how PyTorch uses an object-oriented approach to define basic building blocks and gives us some 'rails' to move on, while still letting us extend functionality via subclassing.
Here is a slightly modified version of https://github.com/pytorch/examples/blob/master/mnist/main.py:
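The network below is modeled on the classic version of that example (two convolutional layers with dropout, followed by two fully connected layers); treat it as a sketch rather than a verbatim copy of the file. To stay self-contained, it trains for one step on random tensors instead of downloading MNIST through torchvision's datasets module, and the optimizer settings are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    """CNN for 28x28 grayscale digits, modeled on the classic pytorch/examples MNIST network."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)    # 28x28 -> 24x24
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)   # 12x12 -> 8x8
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)                   # 20 channels * 4 * 4 = 320
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))                   # -> 10x12x12
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))  # -> 20x4x4
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        return F.log_softmax(self.fc2(x), dim=1)

# One training step. Random tensors stand in for a real batch; with torchvision
# installed you would iterate over a DataLoader wrapping datasets.MNIST instead.
torch.manual_seed(0)
model = Net()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
images = torch.randn(64, 1, 28, 28)
labels = torch.randint(0, 10, (64,))

model.train()
optimizer.zero_grad()
loss = F.nll_loss(model(images), labels)
loss.backward()
optimizer.step()
print(f"loss after one step: {loss.item():.4f}")
```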
Plain TensorFlow feels much more like a library than a framework: all operations are pretty low-level, and you will need to write lots of boilerplate code even when you might not want to (let’s define those biases and weights again and again and …).
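To make that boilerplate concrete, here is a sketch of two fully connected layers written with nothing but low-level ops, assuming TensorFlow 2.x eager execution; the layer sizes and initializers are arbitrary choices for illustration.

```python
import tensorflow as tf

tf.random.set_seed(0)
x = tf.random.normal([32, 784])   # stand-in batch of flattened 28x28 images

# Layer 1: weights and biases declared and initialized by hand...
w1 = tf.Variable(tf.random.truncated_normal([784, 128], stddev=0.1), name="w1")
b1 = tf.Variable(tf.zeros([128]), name="b1")
h1 = tf.nn.relu(tf.matmul(x, w1) + b1)

# ...and again for layer 2, and for every layer after that.
w2 = tf.Variable(tf.random.truncated_normal([128, 10], stddev=0.1), name="w2")
b2 = tf.Variable(tf.zeros([10]), name="b2")
logits = tf.matmul(h1, w2) + b2
```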
As time has passed, a whole ecosystem of high-level wrappers has emerged around TensorFlow. Each of them aims to simplify the way you work with the library. Many of them are currently located in the tensorflow.contrib module (which is not considered a stable API), and some have started to migrate to the main repository (see …).
So you have a lot of freedom in how to use TensorFlow and which framework suits the task best: TFLearn, tf.contrib.learn, Sonnet, Keras, plain tf.layers, etc. To be honest, Keras deserves a post of its own, but it is out of the scope of this comparison.
Here we will use tf.contrib.learn to build our CNN classifier. The code follows the official tutorial on tf.layers:
So, both TensorFlow and PyTorch provide useful abstractions that reduce the amount of boilerplate code and speed up model development. The main difference between them is that PyTorch may feel more “pythonic” and takes an object-oriented approach, while TensorFlow offers several options from which you may choose.
Personally, I consider PyTorch clearer and more developer-friendly. Its torch.nn.Module gives you the ability to define reusable modules in an OOP manner, and I find this approach very flexible and powerful. Later you can compose all kinds of modules via torch.nn.Sequential (hi Keras ✋🏻). Also, the built-in modules are available in a functional form (torch.nn.functional), which can be very convenient. Overall, all parts of the API play well together.
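A short sketch of both styles, with arbitrary layer sizes: a classifier composed with nn.Sequential, and the same computation written with torch.nn.functional calls that reuse the Sequential model's own parameters.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Composing built-in modules with nn.Sequential (the Keras-like style).
classifier = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
)

x = torch.randn(8, 1, 28, 28)
logits = classifier(x)

# The same computation in functional form, reusing the modules' parameters.
flat = torch.flatten(x, start_dim=1)
hidden = F.relu(F.linear(flat, classifier[1].weight, classifier[1].bias))
manual = F.linear(hidden, classifier[3].weight, classifier[3].bias)
```

The module form is convenient when layers own their parameters; the functional form is handy when you want explicit control over which weights are used.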
Of course, you can write very clean code in plain TensorFlow, but it just takes more skill and trial and error before you get there. When it comes to higher-level frameworks such as Keras or TFLearn, be ready to lose at least some of the flexibility TensorFlow has to offer.
TensorFlow is a very powerful and mature deep learning library with strong visualization capabilities and several options for high-level model development. It has production-ready deployment options and support for mobile platforms. TensorFlow is a good option if you:
- Develop models for production
- Develop models which need to be deployed on mobile platforms
- Want good community support and comprehensive documentation
- Want rich learning resources in various forms (TensorFlow has an entire MOOC)
- Want or need to use TensorBoard
- Need to use large-scale distributed model training
PyTorch is still a young framework that is gaining momentum fast. You may find it a good fit if you:
- Do research, or have non-functional production requirements that are not very demanding
- Want better development and debugging experience
- Love all things Pythonic
If you have the time, the best advice is to try both and see which fits your needs best.
If you liked this article, please leave a few 👏. It lets me know that I am helping.