PyTorch Lightning: Metrics
With PyTorch Lightning 0.8.1 we added a feature that has been requested many times by our community: Metrics. This feature is designed to be used with PyTorch Lightning as well as with any other PyTorch-based code. In this blog post, we’ll explain what Metrics is and how you can get started.
Disclaimer: This post no longer describes the recommended way to use metrics. We redesigned the feature completely and moved it out into a separate package. See this GitHub or this Medium link for reference.
That means the implementation now differs and follows the idea of a torch.nn.Module more closely than before, but the following sections on why a Metric is important are still relevant.
What’s a metric and why do I need this?
The purpose of metrics in general is to let you monitor and quantify the training process. If you use more advanced techniques like learning-rate scheduling or early stopping, you are probably already using metrics for them. While you could also use losses for this, metrics are preferred since they represent the training goal better. The loss (like cross-entropy) pushes the network’s parameters in the right direction, while metrics can give additional insight into the network’s behavior.
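To make the scheduling use case concrete, here is a minimal sketch in plain PyTorch. The single parameter, the SGD optimizer and the validation accuracies are all made up for illustration; only ReduceLROnPlateau itself is a real PyTorch class.

```python
import torch

# Hypothetical setup: one parameter and a plain SGD optimizer.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)

# mode="max" because a higher accuracy is better; patience=0 reduces the
# learning rate as soon as the metric fails to improve once.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="max", factor=0.5, patience=0
)

# Made-up validation accuracies for three epochs; the last one is worse.
for val_accuracy in [0.80, 0.85, 0.84]:
    scheduler.step(val_accuracy)

current_lr = optimizer.param_groups[0]["lr"]
print(current_lr)
```

The scheduler never sees the loss here. It reacts only to the metric, which is the quantity you actually care about.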
In contrast to losses, metrics don’t have to be differentiable (in fact, many of them aren’t), although some are. If a metric is differentiable and implemented in pure PyTorch, you can also backpropagate through it and use it for whatever fancy research you want to do.
In the mathematical sense, a metric is a function that defines a distance between each pair of elements in a set. While this intuition carries over to the deep-learning understanding of a metric, some of the mathematical constraints need not be fulfilled here, since we don’t need metrics to be symmetric in all cases and they also don’t necessarily have to satisfy the triangle inequality.
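For reference, these are the standard textbook axioms that a function d must satisfy for all x, y, z in the set to be a metric in the mathematical sense. The symbol d is just the conventional notation and is not taken from the Metrics package.

```latex
\begin{align*}
d(x, y) &\ge 0, \quad d(x, y) = 0 \iff x = y && \text{(non-negativity, identity)} \\
d(x, y) &= d(y, x) && \text{(symmetry)} \\
d(x, z) &\le d(x, y) + d(y, z) && \text{(triangle inequality)}
\end{align*}
```

A deep-learning metric such as accuracy compares predictions with targets, and nothing forces it to obey the last two lines.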
How to use it
Now that we all have a basic understanding of metrics, let me explain how you can use them in PyTorch Lightning.
Getting started is very simple. There are two ways to do it: a functional way and a module-based way. Let’s first have a look at the functional way, since it is the easier one.
Basically, you write your metric as a function that uses only torch operations.
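As an illustration, a functional metric could look like the following sketch. The choice of root mean squared error and the example values are assumptions made for this snippet.

```python
import torch


def rmse(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Root mean squared error, written with torch operations only."""
    return torch.sqrt(torch.mean((pred - target) ** 2))


# Made-up example values: only the last prediction is off, by 3.
pred = torch.tensor([2.0, 4.0, 6.0])
target = torch.tensor([2.0, 4.0, 9.0])
value = rmse(pred, target)
print(value)
```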
Okay, that was really easy, right? This is plain PyTorch code and does nothing special. But this also worked before we introduced the Metrics package. So why do we need this package at all?
Writing metrics as plain functions works, but you have to make sure manually that the inputs and outputs are all tensors, are all on the correct device and all have the correct type. This is where PyTorch Lightning’s automation approach comes in.
Take a look at the following example:
Note the @tensor_metric() decorator. It converts all inputs to tensors and all outputs to tensors as well (in case you somehow change the type of the result inside the function). Additionally, it synchronizes the Metric's output across all DDP nodes if DDP was initialized.
Besides tensor_metric, there are two other decorators:
tensor_collection_metric converts only the numbers and numpy arrays it finds, which avoids errors with collections that cannot be converted to tensors as a whole (like lists of lists with different lengths). It also syncs across DDP nodes.
numpy_metric converts all inputs to numpy arrays and all outputs to torch tensors. This lets you use your favorite numpy code as a Metric and still get automated DDP syncing.
Note: We strongly recommend using or writing native tensor implementations whenever possible, since each call to a numpy Metric requires a GPU synchronization and may therefore slow down your training substantially.
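To give a feeling for what such a decorator does, here is a rough sketch of the input/output conversion only. The names to_tensor_metric and _to_tensor are hypothetical and this is not Lightning's implementation. In particular, it leaves out the DDP synchronization entirely.

```python
import functools

import torch


def _to_tensor(value):
    """Convert numbers, lists and numpy arrays to tensors; keep tensors as they are."""
    if isinstance(value, torch.Tensor):
        return value
    # torch.as_tensor accepts Python numbers, (nested) lists and numpy arrays.
    return torch.as_tensor(value)


def to_tensor_metric(func):
    """Hypothetical stand-in for a tensor-converting decorator (no DDP sync)."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        args = [_to_tensor(a) for a in args]
        kwargs = {k: _to_tensor(v) for k, v in kwargs.items()}
        # Convert the output too, in case the function returns a plain number.
        return _to_tensor(func(*args, **kwargs))

    return wrapper


@to_tensor_metric
def mean_absolute_error(pred, target):
    return torch.mean(torch.abs(pred - target))


# Plain Python lists go in, a tensor comes out.
result = mean_absolute_error([1.0, 2.0, 3.0], [1.0, 2.0, 5.0])
print(result)
```

The metric body stays pure torch, and the decorator takes over the tedious type handling around it.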
Module Metrics Interface
The easiest way to provide a module interface for your Metric if you already have a functional interface is as follows:
You just derive a class from our base class and call your functional Metric within the forward method. The base class used here, TensorMetric, already deals with the DDP sync and input/output conversion.
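As a rough sketch of that structure in plain PyTorch: the class below derives from torch.nn.Module as a stand-in for the Lightning base class (an assumption made so the snippet runs without Lightning). It therefore shows only the forward pattern, not the DDP sync or the conversion.

```python
import torch


def rmse(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Functional metric that the module below wraps."""
    return torch.sqrt(torch.mean((pred - target) ** 2))


class RMSE(torch.nn.Module):
    """Module interface; nn.Module stands in for the Lightning base class."""

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # All the actual work is delegated to the functional metric.
        return rmse(pred, target)


metric = RMSE()
value = metric(torch.tensor([1.0, 2.0]), torch.tensor([1.0, 4.0]))
print(value)
```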
There are 3 more metric base classes besides TensorMetric:
NumpyMetric: wrapper for metric functions implemented with numpy
TensorCollectionMetric: wrapper for metrics whose outputs cannot be converted to torch.Tensor completely (like lists of lists with different lengths)
Metric: the most basic class. It does no DDP sync and no input/output conversion. This class should be used for functional metrics that already handle conversion on their own (for example, if they are decorated with one of the decorators above).
Now you may ask, what’s the advantage of these modular interfaces over functional interfaces?
The first one is quite obvious: Metric is a class derived from torch.nn.Module. That means you also gain all of its advantages, like registered buffers whose device and dtype change along with the metric's device or dtype. So you don't have to take care of this yourself.
The next one: we extended torch.nn.Module with a device property, so you can always check which device your tensors should be on (and we also use this for automated device changes if necessary).
Third: in case you want a different DDP reduction than the standard one, or you only want to reduce across a specific group of processes, you can specify this when initializing these classes. With functionals, the decorator specifies these arguments, so you cannot change them that easily.
Fourth: we plan to introduce additional things for these classes, like automated metric-specific aggregation across batches, automated metric evaluation, etc., which will also come built in.
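The first two points can be sketched in plain PyTorch. RunningMean and its device property are hypothetical and written only for this illustration. They show what buffers and a device property give you, not how the Lightning base classes are implemented.

```python
import torch


class RunningMean(torch.nn.Module):
    """Hypothetical metric-like module that keeps its state in buffers."""

    def __init__(self):
        super().__init__()
        # Buffers follow the module when its device or dtype is changed.
        self.register_buffer("total", torch.zeros(1))
        self.register_buffer("count", torch.zeros(1))

    @property
    def device(self) -> torch.device:
        # Assumption for this sketch: the module lives where its buffers live.
        return self.total.device

    def forward(self, values: torch.Tensor) -> torch.Tensor:
        values = values.to(self.device)
        self.total += values.sum()
        self.count += values.numel()
        return self.total / self.count


metric = RunningMean()
metric(torch.tensor([1.0, 2.0, 3.0]))

# Changing the dtype of the module also changes the dtype of its buffers.
metric = metric.to(torch.float64)
mean = metric(torch.tensor([5.0]))
print(metric.device, metric.total.dtype, mean)
```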
This also brings me to my next and almost last point:
Future Plans with this package
Currently, all of these Metrics can be used with or without PyTorch Lightning. We will make sure to keep it that way, BUT we will introduce some convenience features like the ones mentioned above, which will probably be deeply integrated with PyTorch Lightning and not as easy to use without it.
Already Available Metrics
I spent a lot of time explaining how metrics work in PyTorch Lightning, but our aim with this package is to collect and consolidate common metrics behind a single interface for more standardized usage and research. Therefore, we also support all the scikit-learn metrics.
Note: since these Metrics are implemented with numpy, they can also slow down your training substantially.
We also started implementing a growing list of native Metrics like accuracy, auroc, average precision and about 20 others (as of today!).
You can see the documentation of the Metrics’ package here.
Whenever you wrote some fancy application and wanted to run it in a distributed fashion, you had to sync metrics/results manually. That’s over now! The module interface synchronizes the Metric’s outputs across a specific process group with a specific operation (by default, a sum). We are already working on a separate version that includes Metric-specific reductions.
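For comparison, a manual sync of this kind might look roughly like the sketch below. The helper name sync_if_distributed is hypothetical and this is not the package's code. It uses only torch.distributed calls, and it simply returns the local value when no process group has been initialized.

```python
import torch
import torch.distributed as dist


def sync_if_distributed(result: torch.Tensor, group=None) -> torch.Tensor:
    """Sum `result` over all processes in `group` if a process group is initialized."""
    if dist.is_available() and dist.is_initialized():
        # all_reduce works in place, so clone to leave the caller's tensor untouched.
        result = result.clone()
        dist.all_reduce(result, op=dist.ReduceOp.SUM, group=group)
    return result


# Outside a distributed run this is a no-op and returns the local value.
local_value = torch.tensor(3.0)
synced = sync_if_distributed(local_value)
print(synced)
```

A plain sum is not the right reduction for every metric, which is why Metric-specific reductions are the natural next step.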
Usage without PyTorch Lightning
As already mentioned, these metrics (even the already implemented ones!) can also be used without anything else from PyTorch Lightning!
Let’s have a look at this short example:
As you can see, you can use it by just importing the Metric, without any other changes required!
At the very end, I also have one thing to ask of you: if you have an implementation of a Metric that we have not implemented yet, please consider opening an issue and pinging me (@justusschock) for discussion and (guidance on) implementation. I’ll try my best to answer as soon as possible to make sure we all get those Metrics standardized.