Finding why Pytorch Lightning made my training 4x slower.

🤔 What happened?

Two weeks ago, I refactored some deep learning researchers’ code to Pytorch Lightning, expecting roughly a 1.5x speedup. Instead, I got a 4x slowdown of the training, evaluation, and testing tasks. This neural network was supposed to run for days before giving any good results, and I needed to cut training time as much as possible.

My name is , and this is the story of how I tracked down a major bug in Pytorch Lightning.

📚 Some context.

I was working with some open-source Deep Learning code. It was made by two researchers and demonstrated a new state-of-the-art architecture for some machine learning tasks.

However, as usual in research, the code itself was neither clean nor optimized. I noticed several places where a speedup was possible and refactored the code into proper Pytorch. My training was now ~3 times faster.

However, I was sure there was still room for improvement. In that area, I think Pytorch Lightning is the best tool available: it removes a lot of boilerplate and bundles several optimization methods, such as 16-bit precision and Stochastic Weight Averaging. I decided to refactor my own code using Lightning.

This was not my first time doing so, and I was expecting roughly a 1.5x speedup. When I finished refactoring, I was quite surprised to see my iteration time go from 4 seconds to 15 seconds, making my training effectively 4 times slower.

Take your time, Lightning.

🔍 Getting some clues.

I started by running Lightning’s profilers to find where the problem was.

The basic profiler gave me a starting point: most of the time was spent running an epoch. The advanced profiler didn’t provide any more information, which was disappointing but not surprising. I decided to first check whether the problem was on my side.

I wondered whether I had misconfigured some hyperparameter of my neural network. I tweaked some of them without noticing any difference. I then tweaked my DataLoaders and discovered that changing the number of jobs, n_jobs, had an impact on the total training time. However, instead of speeding up my computations, adding more jobs slowed them down.

Time for 100 epochs, depending on the number of jobs

Entirely disabling multiprocessing with n_jobs=0 made my iterations almost 2x faster than using 6 cores. By default, Pytorch shuts down and restarts its workers at every epoch, which causes the dataset to be reloaded.

In my case, loading the dataset was very slow. However, I had set the persistent_workers parameter to True on my DataLoader, which keeps the workers alive between epochs and prevents the data from being reloaded.

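# My DataLoader parameters
from torch.utils.data import DataLoader

train_loader = DataLoader(
    train_dataset, batch_size=64, shuffle=True, num_workers=n_workers,  # n_workers > 0
    persistent_workers=True,  # keep worker processes alive between epochs
    pin_memory=True,
)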

Therefore, there were two possibilities:

  • Pytorch Lightning kills workers regardless of the persistent_workers parameter;
  • The problem lay somewhere else.

I decided to open an issue to make the Lightning team aware of my problem, and kept searching for its root cause.

🕵️ Finding the culprit

Lightning’s profiler works with context managers and measures how much time a given block took. This makes it easy to search the source code for a specific profiler action, run_training_epoch in our case.
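To illustrate the idea, here is a minimal sketch of a context-manager profiler. ToyProfiler is a hypothetical class, much simpler than Lightning’s profilers; it only shows why an action name such as run_training_epoch is a convenient search key: it is just a string passed to the with block being timed.

```python
import time
from collections import defaultdict
from contextlib import contextmanager


class ToyProfiler:
    """Hypothetical minimal profiler: accumulates wall-clock time per named action."""

    def __init__(self):
        self.durations = defaultdict(list)

    @contextmanager
    def profile(self, action_name):
        start = time.perf_counter()
        try:
            yield
        finally:
            # Record the elapsed time even if the profiled block raised.
            self.durations[action_name].append(time.perf_counter() - start)

    def summary(self):
        return {name: sum(times) for name, times in self.durations.items()}


profiler = ToyProfiler()
for epoch in range(3):
    with profiler.profile("run_training_epoch"):
        time.sleep(0.01)  # stand-in for a real training epoch

print(profiler.summary())
```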

I had my starting point. From there, I wandered into Lightning’s source code and started looking for the instructions that made my loops slower. After digging around, I found the problematic lines:

Loop.run calls Loop.on_run_start…
And Loop.on_run_start reloads the dataloader.
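The pattern can be sketched with a toy loop, assuming a simplified structure where run() calls on_run_start() once and then advance() until the loop is done. ToyLoop and ToyEpochLoop are hypothetical classes, not Lightning’s code. Because a fit loop calls the epoch loop’s run() once per epoch, anything expensive in on_run_start(), such as reloading the dataloader, is paid on every epoch.

```python
class ToyLoop:
    """Hypothetical base loop: run() drives the hooks in a fixed order."""

    def run(self, *args, **kwargs):
        self.reset()
        self.on_run_start(*args, **kwargs)  # executed once per call to run()
        while not self.done:
            self.advance(*args, **kwargs)
        return self.on_run_end()

    def reset(self):
        pass

    def on_run_start(self, *args, **kwargs):
        pass

    def on_run_end(self):
        return None


class ToyEpochLoop(ToyLoop):
    """Hypothetical epoch loop whose on_run_start stands in for reloading the dataloader."""

    def __init__(self, num_batches):
        self.num_batches = num_batches
        self.batch_idx = 0
        self.dataloader_reloads = 0

    @property
    def done(self):
        return self.batch_idx >= self.num_batches

    def reset(self):
        self.batch_idx = 0

    def on_run_start(self, *args, **kwargs):
        self.dataloader_reloads += 1  # the expensive step, repeated every epoch

    def advance(self, *args, **kwargs):
        self.batch_idx += 1  # process one batch


epoch_loop = ToyEpochLoop(num_batches=5)
for epoch in range(4):  # a fit loop calls the epoch loop's run() once per epoch
    epoch_loop.run()
print(epoch_loop.dataloader_reloads)  # 4
```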

Bingo! It looks like the problem indeed comes from the DataLoader being reloaded at each epoch. I then looked at the DataLoader source code and found this:

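When iterating through a DataLoader with persistent_workers=True and num_workers > 0, if _iterator is None, a new iterator is created with _get_iterator(), which starts fresh workers and reloads the entire dataset. Otherwise, the existing iterator is simply reset and its workers are reused. Pytorch Lightning must have been resetting _iterator to None by mistake, leading to the issue.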
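A torch-free sketch of that branch, assuming a simplified model where starting workers is reduced to a counter. ToyPersistentLoader and _ToyIter are hypothetical classes that only imitate the logic of torch.utils.data.DataLoader.__iter__; they show that clearing _iterator between epochs is enough to restart the workers every time.

```python
class _ToyIter:
    """Hypothetical iterator standing in for a multi-process DataLoader iterator."""

    def __init__(self, data):
        self.data = data
        self.pos = 0

    def reset(self):
        self.pos = 0  # reuse the same "workers" for a new epoch

    def __iter__(self):
        return self

    def __next__(self):
        if self.pos >= len(self.data):
            raise StopIteration
        item = self.data[self.pos]
        self.pos += 1
        return item


class ToyPersistentLoader:
    """Hypothetical imitation of DataLoader.__iter__ with persistent workers."""

    def __init__(self, data, num_workers=2, persistent_workers=True):
        self.data = data
        self.num_workers = num_workers
        self.persistent_workers = persistent_workers
        self._iterator = None
        self.worker_starts = 0  # how many times workers (and dataset copies) were created

    def _get_iterator(self):
        self.worker_starts += 1  # expensive in real life: new processes, dataset reloaded
        return _ToyIter(self.data)

    def __iter__(self):
        if self.persistent_workers and self.num_workers > 0:
            if self._iterator is None:
                self._iterator = self._get_iterator()
            else:
                self._iterator.reset()
            return self._iterator
        return self._get_iterator()


loader = ToyPersistentLoader(list(range(4)))
for epoch in range(3):
    assert list(loader) == [0, 1, 2, 3]
print(loader.worker_starts)  # 1

buggy = ToyPersistentLoader(list(range(4)))
for epoch in range(3):
    list(buggy)
    buggy._iterator = None  # what the bug effectively did between epochs
print(buggy.worker_starts)  # 3
```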

To confirm this theory, I replaced the DataLoader with a custom one overloading only the __iter__ method:

As expected, after an iteration the _iterator attribute was correctly set but was reset to None before the beginning of the next epoch.

n_jobs=1, persistent_workers=True

I just needed to know when the attribute was set to None to find the source of the problem. I tried using my debugger, but it crashed, due to either multiprocessing or CUDA. I resorted to using Python’s getter & setter, through a property:

This will print a stack trace whenever DataLoader._iterator is set to None
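A minimal sketch of the getter & setter trick with a Python property. TraceIteratorResets, ToyLoader and shutdown_workers_and_reset_iterator are hypothetical names. The same mixin could presumably be combined with a real DataLoader subclass, e.g. class TracedDataLoader(TraceIteratorResets, DataLoader), but that combination is an untested assumption here.

```python
import traceback


class TraceIteratorResets:
    """Hypothetical debugging mixin: print a stack trace whenever `_iterator` is set to None."""

    @property
    def _iterator(self):
        return self.__dict__.get("_traced_iterator")

    @_iterator.setter
    def _iterator(self, value):
        if value is None:
            traceback.print_stack()  # shows which caller reset the iterator
            self.__dict__["none_assignments"] = self.__dict__.get("none_assignments", 0) + 1
        self.__dict__["_traced_iterator"] = value


class ToyLoader(TraceIteratorResets):
    """Stand-in for a DataLoader; only the `_iterator` attribute matters here."""

    def __init__(self):
        self._iterator = None  # this first assignment is traced too


def shutdown_workers_and_reset_iterator(loader):
    # Hypothetical stand-in for framework code that clears the cached iterator.
    loader._iterator = None


loader = ToyLoader()
loader._iterator = iter([1, 2, 3])           # not None: nothing printed
shutdown_workers_and_reset_iterator(loader)  # prints a trace going through this function
print(loader.none_assignments)  # 2
```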

It worked like a charm, and I got this output:

File "trainer\trainer.py", line 1314, in _run_train
self.fit_loop.run()
...
File "loops\fit_loop.py", line 234, in advance
self.epoch_loop.run(data_fetcher)
File "loops\base.py", line 139, in run
self.on_run_start(*args, **kwargs)
File "loops\epoch\training_epoch_loop.py", line 142, in on_run_start
self._dataloader_iter = _update_dataloader_iter(...)
File "loops\utilities.py", line 121, in _update_dataloader_iter
dataloader_iter = enumerate(data_fetcher, batch_idx)
File "utilities\fetching.py", line 198, in __iter__
self.reset()
File "utilities\fetching.py", line 212, in reset
self.dataloader.reset()
...
File "trainer\supporters.py", line 498, in _shutdown_workers_and_reset_iterator
dataloader._iterator = None

This stack trace shows that the DataLoader is reset every time the epoch loop’s run starts. After digging through the code, I found that the DataFetcher was reset each time it was iterated over, which reset the DataLoader too. The reset was unconditional: every epoch reset the DataLoader.

That was the root cause of my impressive slowdown.

⚙️ Fixing the bug.

Fixing the bug was actually pretty simple, and pretty dumb: I removed the self.reset line from DataFetcher’s __iter__ method.
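The bug and its fix can be modeled with a simplified fetcher. ToyLoader, ToyDataFetcher and the reset_on_iter flag are hypothetical; Lightning’s actual DataFetcher does much more. The flag toggles the self.reset() call at the start of __iter__, which is the line that was removed.

```python
class ToyLoader:
    """Hypothetical stand-in for a DataLoader with persistent workers."""

    def __init__(self, data):
        self.data = data
        self._iterator = None
        self.worker_starts = 0

    def __iter__(self):
        if self._iterator is None:
            self.worker_starts += 1  # expensive: new workers, dataset reloaded
            self._iterator = object()  # marker meaning "workers are alive"
        return iter(self.data)

    def reset(self):
        self._iterator = None  # shuts down the workers


class ToyDataFetcher:
    """Hypothetical fetcher; reset_on_iter=True mimics the behaviour before the fix."""

    def __init__(self, dataloader, reset_on_iter):
        self.dataloader = dataloader
        self.reset_on_iter = reset_on_iter

    def reset(self):
        self.dataloader.reset()

    def __iter__(self):
        if self.reset_on_iter:
            self.reset()  # the call removed by the fix
        return iter(self.dataloader)


def count_worker_starts(reset_on_iter, epochs=5):
    loader = ToyLoader(list(range(8)))
    fetcher = ToyDataFetcher(loader, reset_on_iter)
    for _ in range(epochs):
        for _batch in fetcher:
            pass
    return loader.worker_starts


print(count_worker_starts(reset_on_iter=True))   # 5: workers restarted every epoch
print(count_worker_starts(reset_on_iter=False))  # 1: workers persist across epochs
```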

I then ran my training, and… Tada! A single iteration now takes only 1.5 seconds, compared to 15s before and 3s with vanilla Pytorch. It’s a nice speedup!

💡 Outcome

After I discussed the issue with the Lightning team, they resolved it and pushed a hotfix the next day. I updated the library, and their fix indeed works. Kudos to the team for their responsiveness!

Judging by some GitHub discussions and some comments on my issue, I know other people have faced this bug. I think many more will benefit from this fix and will see improved training and testing times on their Lightning models. If you haven’t updated your dependencies recently, try installing pytorch-lightning==1.5.1 or higher!
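To check whether an environment already has a fixed release, here is a small sketch, assuming the package is installed under the distribution name pytorch-lightning and that 1.5.1 is the first fixed version, as stated above. It only compares the numeric release part, so a pre-release such as 1.5.1rc0 would be treated as 1.5.1.

```python
import re
from importlib import metadata

MIN_FIXED = (1, 5, 1)  # release recommended in the article


def parse_release(version):
    """Return the numeric release part: '1.5.10' -> (1, 5, 10), '1.6.0rc1' -> (1, 6, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def has_fix(version):
    # Tuple comparison orders versions component by component.
    return parse_release(version) >= MIN_FIXED


def installed_lightning_has_fix():
    try:
        return has_fix(metadata.version("pytorch-lightning"))
    except metadata.PackageNotFoundError:
        return False  # not installed at all


print(installed_lightning_has_fix())
```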

Thanks for reading my first article! I hope you found it interesting. Don’t hesitate to send me feedback, or to reach out about anything AI-related.

Engineer, entrepreneur, loving AI and data in general.
