Finding why Pytorch Lightning made my training 4x slower.
🤔 What happened?
Two weeks ago, I refactored some deep learning researchers’ code to Pytorch Lightning, expecting approximately a 1.5x speedup. Instead, training, evaluation, and testing all became 4x slower. This neural network was supposed to run for days before producing any good results, and I needed to cut down training time as much as possible.
My name is Florian Ernst, and this is the story of how I tracked down a major bug in Pytorch Lightning.
📚 Some context.
I was working with some open-source Deep Learning code. It was made by two researchers and demonstrated a new state-of-the-art architecture for some machine learning tasks.
However, as usual in research, the code itself was neither clean nor optimized. I noticed several places where a speedup was possible and refactored the code into proper Pytorch. My training was now ~3 times faster.
However, I was sure there was still room for improvement. In that area, I think Pytorch Lightning is the best tool available: it removes a lot of boilerplate and bundles several optimization methods, like 16-bit precision or Stochastic Weight Averaging. I decided to refactor my own code using Lightning.
This was not my first time doing so, and I was expecting a ~1.5x speedup. When I finished my refactoring, I was quite surprised to see that the time per iteration went from 4 seconds to 15 seconds, making my training effectively 4 times slower.
🔍 Getting some clues.
I started by running Lightning’s profilers to find where the problem was.
The basic profiler gave me a starting point: most of the time was spent running an epoch. The advanced profiler didn’t bring more information, which was disappointing but not surprising. I decided to first check whether the problem was on my side.
I wondered whether I had misconfigured some hyperparameter of my neural network. I changed several of them without noticing any difference. I then tweaked my DataLoaders and discovered that changing the number of jobs, n_jobs, had an impact on the total training time. However, instead of speeding up my calculations, adding jobs slowed them down.
Entirely disabling multiprocessing with n_jobs=0 made my iterations almost 2x faster than using 6 cores. By default, Pytorch kills and recreates the DataLoader workers between epochs, causing the dataset to be reloaded.
In my case, loading the dataset was very slow. However, I had set the persistent_workers parameter of my DataLoader to True. This prevents the workers from being killed, and therefore the data from being reloaded.
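from torch.utils.data import DataLoader

# My DataLoader parameters
train_loader = DataLoader(
    train_dataset, batch_size=64, shuffle=True, num_workers=n_workers,
    # Keep workers alive between epochs; PyTorch requires num_workers > 0 for this
    persistent_workers=True, pin_memory=True,
)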
Therefore, there were two possibilities:
- Pytorch Lightning kills workers regardless of the persistent_workers parameter;
- The problem lay somewhere else.
I decided to open an issue on GitHub to make the Lightning team aware of my problem, and kept searching for its root cause.
🕵️ Finding the culprit
Lightning’s profiler works with context managers: it measures how long a given block takes and records that time under an action name. This makes it easy to search the source code for a specific profiler action, run_training_epoch in our case.
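To illustrate the idea, here is a minimal sketch of a context-manager profiler in plain Python. The SimpleProfiler class and its profile and total methods are hypothetical names; the code only imitates the concept and is not Lightning’s implementation.

```python
import time
from collections import defaultdict
from contextlib import contextmanager


class SimpleProfiler:
    """Hypothetical profiler: records wall-clock durations under an action name."""

    def __init__(self):
        self.recorded = defaultdict(list)  # action name -> list of durations (s)

    @contextmanager
    def profile(self, action_name):
        start = time.perf_counter()
        try:
            yield
        finally:
            # Record the duration even if the wrapped block raises.
            self.recorded[action_name].append(time.perf_counter() - start)

    def total(self, action_name):
        # Total time spent in all blocks recorded under this action name.
        return sum(self.recorded.get(action_name, []))


profiler = SimpleProfiler()
for epoch in range(2):
    with profiler.profile("run_training_epoch"):
        time.sleep(0.01)  # stand-in for one epoch of training

for name, durations in profiler.recorded.items():
    print(f"{name}: {len(durations)} calls, {sum(durations):.3f}s total")
```

Because each action is identified by a plain string, searching the library source for that string leads straight to the block being timed.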
I had my starting point. From there, I wandered into Lightning’s source code and looked for the instructions that made my loops slower. After digging around, I found the problematic lines:
Bingo! The problem did seem to come from the DataLoader being reloaded at each epoch. I then looked at the DataLoader source code and found this:
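When iterating through a DataLoader with persistent_workers=True and num_workers > 0, if _iterator is None, a new iterator is created with _get_iterator(), which starts fresh workers and reloads the entire dataset. Otherwise, the existing iterator is simply reset and its workers are reused. Pytorch Lightning must have been resetting _iterator to None by mistake, leading to the issue.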
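The branching described above can be mimicked with a toy class. ToyLoader, _ToyIterator and the iterators_created counter are hypothetical stand-ins, and the code is a simplified paraphrase rather than PyTorch’s actual source: creating a new iterator stands in for spawning workers and reloading the dataset.

```python
class _ToyIterator:
    """Hypothetical iterator; in PyTorch it would be backed by worker processes."""

    def __init__(self, data):
        self.data = data
        self.pos = 0

    def _reset(self):
        # Cheap path: rewind and reuse the existing "workers".
        self.pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.pos >= len(self.data):
            raise StopIteration
        item = self.data[self.pos]
        self.pos += 1
        return item


class ToyLoader:
    """Hypothetical stand-in mimicking how a DataLoader caches its iterator."""

    def __init__(self, data, num_workers=2, persistent_workers=True):
        self.data = data
        self.num_workers = num_workers
        self.persistent_workers = persistent_workers
        self._iterator = None
        self.iterators_created = 0  # counts the expensive (re)initializations

    def _get_iterator(self):
        # Expensive path: this is where real workers would be started.
        self.iterators_created += 1
        return _ToyIterator(self.data)

    def __iter__(self):
        if self.persistent_workers and self.num_workers > 0:
            if self._iterator is None:
                self._iterator = self._get_iterator()
            else:
                self._iterator._reset()
            return self._iterator
        return self._get_iterator()


loader = ToyLoader([1, 2, 3])
for epoch in range(3):
    assert list(loader) == [1, 2, 3]
print(loader.iterators_created)  # 1: the same iterator is reused across epochs

loader._iterator = None  # what an unwanted reset between epochs does
list(loader)
print(loader.iterators_created)  # 2: the expensive path ran again
```

With persistent_workers=False, every epoch takes the expensive path, which is the per-epoch reloading the persistent workers were meant to avoid.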
To confirm this theory, I replaced the DataLoader with a custom one overriding only the __iter__ method:
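A minimal sketch of such an override, written as a debugging mixin so it can sit in front of any loader class. IterLoggingMixin, StandInLoader, LoggingLoader and iter_log are hypothetical names; StandInLoader only replaces torch.utils.data.DataLoader so the sketch runs without PyTorch. With the real class, the same pattern would be written as class LoggingDataLoader(IterLoggingMixin, DataLoader).

```python
class IterLoggingMixin:
    """Hypothetical debugging mixin: reports whether `_iterator` is set
    before and after each call to __iter__."""

    def __iter__(self):
        before = getattr(self, "_iterator", None) is not None
        iterator = super().__iter__()  # delegate to the real loader
        after = getattr(self, "_iterator", None) is not None
        print(f"__iter__: _iterator set before={before}, after={after}")
        self.iter_log = getattr(self, "iter_log", []) + [(before, after)]
        return iterator


class StandInLoader:
    """Hypothetical stand-in for a DataLoader that caches an iterator."""

    def __init__(self, data):
        self.data = data
        self._iterator = None

    def __iter__(self):
        if self._iterator is None:
            self._iterator = "cached iterator"  # placeholder for worker-backed state
        return iter(self.data)


class LoggingLoader(IterLoggingMixin, StandInLoader):
    pass


loader = LoggingLoader([1, 2, 3])
list(loader)             # before=False, after=True
list(loader)             # before=True, after=True
loader._iterator = None  # simulate an unwanted reset between epochs
list(loader)             # before=False, after=True
```

Listing the mixin first in the bases ensures its __iter__ runs before the loader’s own, while super().__iter__() keeps the original behavior intact.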
As expected, after an iteration the _iterator attribute was correctly set, but it was reset to None before the beginning of the next epoch.
To find the source of the problem, I only needed to know where the attribute was set to None. I tried using my debugger, but it crashed because of multiprocessing or CUDA. Instead, I resorted to Python’s getters and setters:
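A minimal sketch of this trick, assuming the attribute is named _iterator as above: a property intercepts every assignment and records the call stack whenever the new value is None. WatchIteratorMixin, reset_stacks, StandInLoader, WatchedLoader and shutdown_and_reset are hypothetical names used only for illustration. If the base class assigns _iterator = None in its constructor, that assignment is captured too.

```python
import traceback


class WatchIteratorMixin:
    """Hypothetical debugging mixin: records (and prints) the call stack
    every time `_iterator` is set to None."""

    @property
    def _iterator(self):
        return self.__dict__.get("_watched_iterator")

    @_iterator.setter
    def _iterator(self, value):
        if value is None:
            stack = traceback.format_stack()  # frames leading to this assignment
            self.__dict__.setdefault("reset_stacks", []).append(stack)
            print("".join(stack))
        self.__dict__["_watched_iterator"] = value


class StandInLoader:
    """Hypothetical stand-in for a loader that initializes `_iterator`."""

    def __init__(self):
        self._iterator = None  # also captured by the watcher


class WatchedLoader(WatchIteratorMixin, StandInLoader):
    pass


def shutdown_and_reset(loader):
    # Hypothetical stand-in for library code that resets the iterator.
    loader._iterator = None


loader = WatchedLoader()
loader._iterator = "cached iterator"
shutdown_and_reset(loader)  # the printed stack shows shutdown_and_reset as the caller
```

Unlike a debugger breakpoint, this only needs the process to print, which is why it tends to survive setups where attaching a debugger fails.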
It worked like a charm, and I got this output:
  File "trainer\trainer.py", line 1314, in _run_train
    self.fit_loop.run()
  ...
  File "loops\fit_loop.py", line 234, in advance
    self.epoch_loop.run(data_fetcher)
  File "loops\base.py", line 139, in run
    self.on_run_start(*args, **kwargs)
  File "loops\epoch\training_epoch_loop.py", line 142, in on_run_start
    self._dataloader_iter = _update_dataloader_iter(...)
  File "loops\utilities.py", line 121, in _update_dataloader_iter
    dataloader_iter = enumerate(data_fetcher, batch_idx)
  File "utilities\fetching.py", line 198, in __iter__
    self.reset()
  File "utilities\fetching.py", line 212, in reset
    self.dataloader.reset()
  ...
  File "trainer\supporters.py", line 498, in _shutdown_workers_and_reset_iterator
    dataloader._iterator = None
This stack trace shows that DataLoader.reset is called every time the training epoch loop starts. After digging through the code, I found that the DataFetcher reset itself every time it was iterated over, which reset the DataLoader too. No condition guarded this reset: every epoch unconditionally reset the DataLoader.
That was the root cause of my impressive slowdown.
⚙️ Fixing the bug.
Fixing the bug was actually pretty simple, and pretty dumb: I removed the self.reset() line from DataFetcher’s __iter__ method.
I then ran my training, and… Tada! A single iteration now takes only 1.5 seconds, compared to 15s before and 3s with vanilla Pytorch. It’s a nice speedup!
💡 Outcome
After I shared my findings with the Lightning team, they resolved the issue and pushed a hotfix the next day. I updated the library, and their fix indeed works. Kudos to the team for their responsiveness!
Judging from some GitHub discussions and comments on my issue, I know other people ran into this bug. I think many more will benefit from this fix and will see improved training and testing times on their Lightning models. If you haven’t updated your dependencies recently, try installing pytorch-lightning>=1.5.1!
Thanks for reading my first article! I hope you found it interesting. Don’t hesitate to send me feedback, or to contact me regarding AI.