Dealing with a memory leak in Keras model training
Recently, I was trying to train my Keras (v2.4.3) model with a tensorflow-gpu (v2.2.0) backend on an NVIDIA Tesla V100-DGXS-32GB. When training for a large number of epochs, I noticed a steady memory build-up, i.e. a leak: as training progressed, the process consumed more and more memory until none was left, crashing the job or the system.
One look around the internet made it clear that this problem has been around for some time. Some users linked the issue to model.predict(), which I had included in my callbacks. In the same discussion, a suggested solution was:
- Instead of passing an np.array to model.predict(), pass a tensor by using tf.convert_to_tensor(). The associated explanation mentions that
a for loop with a numpy input creates a new graph on every iteration, because each numpy array comes with a different signature. Converting the numpy array to a tensor keeps the signature the same and avoids creating new graphs.
- Going by the explanation above, another proposed solution I could find was to replace model.predict() with model.predict_on_batch().
- I also tried cloning the trained model with keras.models.clone_model(model) and running predictions on the clone, i.e. cloned_model.predict(). After the predict step I would delete the cloned model, hoping this would handle the memory build-up. (All three attempts are sketched right after this list.)
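For concreteness, here is a minimal sketch of those three workarounds. The tiny stand-in model and validation array are mine, added only to keep the snippet self-contained; they are not from my actual training setup.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

# Tiny stand-in model and data, only so the sketch runs on its own
model = keras.Sequential([keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer="adam", loss="mse")
x_val = np.random.rand(32, 4).astype("float32")

# 1) Pass a tensor instead of a numpy array, so the traced signature stays the same
preds = model.predict(tf.convert_to_tensor(x_val))

# 2) Use predict_on_batch() instead of predict()
preds = model.predict_on_batch(x_val)

# 3) Predict on a throwaway clone of the trained model, then delete it
cloned_model = keras.models.clone_model(model)
cloned_model.set_weights(model.get_weights())
preds = cloned_model.predict(x_val)
del cloned_model
```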
Unfortunately, all of the above strategies failed to solve the problem for me, even though the first couple seem to have worked for others. The solution that ultimately worked:
- Include gc.collect() and keras.backend.clear_session() after the model.predict() call.
With this change, sketched below, I could finally train the model with constant memory usage.
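In code, the change is simply the two clean-up calls after the prediction. The `model` and `x_val` names here are placeholders standing in for whatever your prediction callback already uses.

```python
import gc
import numpy as np
from tensorflow import keras

# Stand-in model and data, just to make the sketch self-contained
model = keras.Sequential([keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer="adam", loss="mse")
x_val = np.random.rand(32, 4).astype("float32")

preds = model.predict(x_val)   # the call that was leaking
# ... use `preds` (metrics, logging, etc.) ...

gc.collect()                   # force Python garbage collection
keras.backend.clear_session()  # drop Keras' accumulated global graph state
```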
This solution held only as long as training ran with a single worker. With multiple workers, the problem got even worse: memory built up much faster than originally observed.
Now, since I had already attempted numerous fixes around the model.predict() call, the problem likely came from within model.fit() itself. To verify this, I removed the callback containing the model.predict() call and, true to my hypothesis, the issue persisted.
On further research, I discovered that the issue originates from how the validation data generator is handled. In the linked issue, the problem is observed with tf.data.Dataset.from_generator(), whereas I was using an ImageDataGenerator().flow_from_dataframe() based generator. So the issue seems to be independent of the data generator type.
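For context, my validation generator looked roughly like the following. The CSV path, column names, image size, and batch size here are hypothetical; the point is only that this generator was originally passed straight to model.fit() via validation_data.

```python
import pandas as pd
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Hypothetical dataframe listing validation image paths and labels
val_df = pd.read_csv("validation.csv")  # columns: "filename", "label"

val_gen = ImageDataGenerator(rescale=1.0 / 255).flow_from_dataframe(
    dataframe=val_df,
    x_col="filename",
    y_col="label",
    target_size=(224, 224),
    batch_size=32,
    shuffle=False,
)
# Originally: model.fit(train_gen, validation_data=val_gen, ...)
```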
After trying permutations and combinations of all the discussed solutions, the only way out I could come up with was to move the evaluation step into a callback, followed by gc.collect() and keras.backend.clear_session(), as sketched below.
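A minimal sketch of that idea follows. The ValidationCallback class and its names are my own illustration, and `model`, `train_gen`, and `val_gen` are assumed to already exist (e.g. `val_gen` from the flow_from_dataframe() sketch above, with `train_gen` built the same way).

```python
import gc
import numpy as np
from tensorflow import keras


class ValidationCallback(keras.callbacks.Callback):
    """Evaluate on the validation generator at the end of each epoch,
    instead of passing validation_data to fit(), then free memory."""

    def __init__(self, val_gen):
        super().__init__()
        self.val_gen = val_gen

    def on_epoch_end(self, epoch, logs=None):
        results = self.model.evaluate(self.val_gen, verbose=0)
        print(f"epoch {epoch + 1} validation:",
              dict(zip(self.model.metrics_names, np.atleast_1d(results))))

        # The clean-up that keeps memory usage flat
        gc.collect()
        keras.backend.clear_session()


# Note: no validation_data argument; evaluation happens inside the callback
model.fit(
    train_gen,
    epochs=100,
    callbacks=[ValidationCallback(val_gen)],
    workers=4,                 # multiple workers now run without the build-up
    use_multiprocessing=False,
)
```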
For some reason, this solution also sped up each training epoch by around 30%.
One possible downside of the above solution:
- The TensorBoard visualization of validation and training loss gets split into separate plots.
However, I feel that is a small price to pay for the massive speed gains possible with multiple workers.
Please do share if you have a better solution to this whole issue.
If you find stories like these valuable and would like to support me as a writer, please consider following me or signing up for Medium membership.