Why Tree-Based Models Beat Deep Learning on Tabular Data

A much-needed reality check for AI Researchers and Engineers caught up in the hype around Deep Learning

Devansh
Geek Culture
8 min read · Aug 27, 2022


Join 31K+ AI People keeping in touch with the most important ideas in Machine Learning through my free newsletter over here

With all the hype around Deep Learning and the new 100-Million Parameter models, it can be easy to forget that these large neural networks are just tools, and they come with their own biases and weaknesses. One of the ideas I stress throughout my content is that you should build a strong base of diverse skill-sets so that you can solve problems in an effective and efficient manner.

In this article, I will be breaking down the paper Why do tree-based models still outperform deep learning on tabular data? The paper explains a phenomenon observed by Machine Learning practitioners all over the world, working in all kinds of domains: tree-based models (like Random Forests) have been much better than Deep Learning/Neural Networks when it comes to analyzing tabular data. I will be sharing the authors' findings to help you understand why this happens, and how you can use these lessons to create the best AI pipelines for the challenges you come across.

Don’t show this graph to all the people with ‘Deep Learning Expert|Podcaster|Blockchain|Software’ in their bio. They will probably start screeching and get violent.

Points to note about the Paper

Before we start looking at the discoveries made by the paper, we should first understand some important aspects of the paper. This will help us contextualize the findings and better evaluate the results. Too many people skip straight to the results and don’t take enough time to evaluate the context. This is a fatal sin, and if you do this, I will stop loving you.

One thing that stood out to me was that the paper had a lot of preprocessing. Some steps, like removing missing data, will handicap tree performance. As I've covered in my article How to handle missing environmental data, Random Forests are very good in situations with missing data. I used them a lot when I was working with Johns Hopkins University to build a system to predict how changing health-system policy would affect public health. The data was extremely noisy, with tons of features and dimensions. The robustness of RFs made them better than more ‘advanced’ solutions, which would break very easily.
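
To make that concrete, here is a minimal sketch of what handling missing values around tree models can look like. This is my own illustration, not something from the paper: the toy data is made up, and I'm assuming scikit-learn, where a common approach is to impute before the forest while the histogram-based gradient-boosting models accept NaNs directly.

```python
# My own toy illustration (not from the paper): two common ways to deal
# with missing values when using tree models in scikit-learn.
import numpy as np
from sklearn.ensemble import RandomForestClassifier, HistGradientBoostingClassifier
from sklearn.impute import SimpleImputer
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import cross_val_score

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 10))
y = (X[:, 0] + X[:, 1] > 0).astype(int)   # label depends on two features
X[rng.random(X.shape) < 0.2] = np.nan     # knock out ~20% of the values

# Option 1: impute first, then fit a Random Forest.
rf = make_pipeline(SimpleImputer(strategy="median"),
                   RandomForestClassifier(n_estimators=200, random_state=0))
print("RF + median imputation:", cross_val_score(rf, X, y, cv=5).mean())

# Option 2: a gradient-boosted tree model that accepts NaNs directly.
hgb = HistGradientBoostingClassifier(random_state=0)
print("HistGradientBoosting  :", cross_val_score(hgb, X, y, cv=5).mean())
```

The point is not the exact numbers, just that tree-based pipelines keep working with fairly little ceremony when values go missing.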

Most of this is pretty standard stuff. I’m personally not a huge fan of applying too many preprocessing techniques because it can lead to you losing a lot of nuance about your dataset, but the steps taken here produce datasets similar to the ones you'd encounter in real projects. However, keep these constraints in mind when evaluating the final results, because they will matter. If your datasets look very different, then take these results with a pinch of salt.

They also used random search for hyperparameter tuning. This is industry standard, but in my experience, Bayesian search is much better for sweeping through more extensive search spaces. I’ll make a video on it soon, so make sure you’re following my YouTube channel to stay updated. The link to that (and all my other work) is at the end of this article.
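
For anyone who hasn't used either, here is a rough sketch of the difference, assuming scikit-learn for the random search and Optuna for the Bayesian-style (TPE) search; the parameter ranges and toy dataset are purely illustrative.

```python
# Rough sketch: random search vs. a Bayesian-style search over RF
# hyperparameters. Assumes scikit-learn and Optuna are installed.
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RandomizedSearchCV, cross_val_score
import optuna

X, y = make_classification(n_samples=1000, n_features=20, random_state=0)

# Random search: hyperparameters are sampled uniformly at random.
random_search = RandomizedSearchCV(
    RandomForestClassifier(random_state=0),
    param_distributions={"n_estimators": list(range(50, 501, 50)),
                         "max_depth": list(range(2, 21))},
    n_iter=20, cv=3, random_state=0)
random_search.fit(X, y)
print("Random search best params:", random_search.best_params_)

# Bayesian-style search: Optuna's default TPE sampler uses the results of
# past trials to decide where to sample next.
def objective(trial):
    clf = RandomForestClassifier(
        n_estimators=trial.suggest_int("n_estimators", 50, 500),
        max_depth=trial.suggest_int("max_depth", 2, 20),
        random_state=0)
    return cross_val_score(clf, X, y, cv=3).mean()

study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=20)
print("Optuna best params:", study.best_params)
```

The difference mostly matters once the search space gets large or each trial is expensive; for a handful of hyperparameters, random search is usually good enough.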

With that out of the way, it's time to answer the main question that brought you to this article: why do tree-based methods beat Deep Learning?

Finding 1: Neural Nets are biased to overly smooth solutions

This is the first reason the authors give for why deep neural networks can't compete with Random Forests on tabular data. Simply put, when it comes to non-smooth functions/decision boundaries, Neural Networks struggle to find the best-fitting functions. Random Forests do much better with weird/jagged/irregular patterns.

If I had to guess why, one possible reason could be the use of gradients in Neural Networks. Gradients rely on differentiable search spaces, which are by definition smooth; pointy, broken, and irregular functions can't be differentiated cleanly. This is one of the reasons I recommend learning about AI concepts like Evolutionary Algorithms and traditional search methods: these more basic techniques can give great results in a variety of situations where NNs fail.
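
A quick toy illustration of that intuition (my own sketch, not the paper's experiment): fit a Random Forest and an MLP to a deliberately jagged 1-D target and compare how each handles the jumps, assuming scikit-learn.

```python
# My own toy sketch (not the paper's experiment): fit an RF and an MLP to a
# deliberately jagged 1-D target and compare how well each tracks the jumps.
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.neural_network import MLPRegressor

rng = np.random.default_rng(0)
X = np.sort(rng.uniform(0, 10, size=(500, 1)), axis=0)
# A non-smooth target: an integer staircase plus a half-step sawtooth.
y = np.floor(X[:, 0]) + 0.5 * (X[:, 0] % 1 > 0.5)

rf = RandomForestRegressor(n_estimators=200, random_state=0).fit(X, y)
mlp = MLPRegressor(hidden_layer_sizes=(64, 64), max_iter=2000,
                   random_state=0).fit(X, y)

X_grid = np.linspace(0, 10, 1000).reshape(-1, 1)
rf_pred, mlp_pred = rf.predict(X_grid), mlp.predict(X_grid)
# Plotting the predictions against X_grid typically shows the RF tracking
# the jumps closely while the MLP smooths them out (details vary by run).
print("RF  train R^2:", round(rf.score(X, y), 3))
print("MLP train R^2:", round(mlp.score(X, y), 3))
```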

For a more concrete example of the difference in decision boundaries between tree-based methods (Random Forests) and deep learners, take a look at the image below:

The better performance of RFs can be attributed to the more precise decision boundaries they generate.

In the Appendix, the authors had the following statement with respect to the above visualization:

In this part, we can see that the RandomForest is able to learn irregular patterns on the x-axis (which corresponds to the date feature) that the MLP does not learn. We show this difference for default hyperparameters but it seems to us that this is a typical behavior of neural networks, and it is actually hard, albeit not impossible, to find hyperparameters to successfully learn these patterns.

This is obviously really important. This becomes even more remarkable when you realize that Tree-Based methods have much lower tuning costs, making them much better when it comes to bang-for-buck solutions.

Finding 2: Uninformative features affect more MLP-like NNs

This is another huge factor, especially for those of you who work with giant datasets that encode multiple relationships at once. If you're feeding irrelevant features to your Neural Network, the results will be terrible (and you will waste a lot more resources training your models). This is why spending a lot of time on EDA/domain exploration is so important. It will help you understand the features and ensure that everything runs smoothly.

The authors of the paper test model performance when adding random features and when removing useless (more correctly, less important) features. Based on their results, two interesting things showed up (there's a small code sketch of this kind of setup below the list):

  1. Removing a lot of features reduced the performance gap between the models. This clearly implies that a big advantage of Trees is their ability to stay insulated from the effects of worse features.
  2. Adding random features to the dataset shows us a much sharper decline in the networks than in the tree-based methods. ResNet especially gets hammered by these useless features. I’m assuming the attention mechanism in the transformer protects it to some degree.
Tree supremacy. One thing to note is that the authors used only Random Forest feature importance to rank features. Combining several importance measures would give a more reliable picture of which features are actually useful.
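
If you want to feel this effect yourself, here is a rough sketch of that kind of experiment (not the paper's exact protocol, and the dataset is synthetic): append pure-noise columns and watch how much each model degrades.

```python
# Sketch (not the paper's exact protocol): append pure-noise features and
# compare how much an RF and an MLP degrade as the noise grows.
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import cross_val_score

rng = np.random.default_rng(0)
X, y = make_classification(n_samples=2000, n_features=15,
                           n_informative=10, random_state=0)

def score(model, features):
    return cross_val_score(model, features, y, cv=5).mean()

rf = RandomForestClassifier(n_estimators=300, random_state=0)
mlp = make_pipeline(StandardScaler(),
                    MLPClassifier(hidden_layer_sizes=(128, 128),
                                  max_iter=500, random_state=0))

for n_noise in (0, 20, 50):
    X_aug = np.hstack([X, rng.normal(size=(X.shape[0], n_noise))]) if n_noise else X
    print(f"{n_noise:>2} noise features | RF: {score(rf, X_aug):.3f}"
          f" | MLP: {score(mlp, X_aug):.3f}")
```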

A possible explanation for this phenomenon might just be in the way Decision Trees are designed. Anyone who has taken an intro to AI class will know about the concepts of Information Gain and Entropy in Decision Trees. At every split, these criteria let the tree compare the remaining features and pick the one that separates the data best, so uninformative features simply don't get chosen. To those not familiar with the concept (or with RFs), I would suggest watching StatQuest's videos on these concepts. I'm linking his guide to Random Forests here.
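
For a quick refresher, here is a minimal worked example of entropy and information gain for a single candidate split; the toy labels are made up for illustration.

```python
# Minimal worked example: entropy and information gain for one split.
import numpy as np

def entropy(labels):
    """Shannon entropy of a label array, in bits."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return -(p * np.log2(p)).sum()

# Toy labels: parent node, then the two children produced by a split.
parent = np.array([1, 1, 1, 1, 0, 0, 0, 0])
left   = np.array([1, 1, 1, 0])   # samples where feature <= threshold
right  = np.array([1, 0, 0, 0])   # samples where feature >  threshold

weighted_child_entropy = (len(left) * entropy(left) +
                          len(right) * entropy(right)) / len(parent)
info_gain = entropy(parent) - weighted_child_entropy
print(f"Information gain of this split: {info_gain:.3f} bits")
```

At each node, the tree simply keeps the split with the highest gain, so columns that never help simply never get used.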

Getting back to the point, there is one final thing that makes RFs better performers than NNs when it comes to tabular data. That is rotational invariance.

Finding 3: NNs are invariant to rotation. Actual Data is not

Neural Networks are invariant to rotation: if you rotate the dataset, their performance does not change. After rotating the datasets, the performance ranking of the different learners flips, with ResNets (which were previously the worst) coming out on top. They maintain their original performance, while all the other learners lose quite a bit of performance.

This is pretty interesting, but I have to learn more about it. Specifically, what does rotating datasets actually mean? I looked through the paper, but couldn’t find the details. I have reached out to the authors and will write a follow-up. Seeing some examples of rotated datasets would help me understand the implications of this finding better. If any of you have any ideas, share them with me in the comments/through my links.
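
My working guess, and I want to stress that this is an assumption on my part rather than something I confirmed in the paper, is that "rotating" means multiplying the feature matrix by a random orthogonal matrix, so every new column becomes a linear mix of the original features. Here is a sketch of that interpretation, assuming scikit-learn and SciPy:

```python
# Sketch of my guess at "rotating a dataset" (an assumption, not verified
# against the paper): multiply the feature matrix by a random orthogonal
# matrix, then compare each model before and after the rotation.
import numpy as np
from scipy.stats import ortho_group
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import cross_val_score

X, y = make_classification(n_samples=2000, n_features=15,
                           n_informative=10, random_state=0)
Q = ortho_group.rvs(dim=X.shape[1], random_state=0)  # random orthogonal matrix
X_rot = X @ Q                                        # every column is now a mix

for name, model in [("RF", RandomForestClassifier(n_estimators=300, random_state=0)),
                    ("MLP", MLPClassifier(hidden_layer_sizes=(128, 128),
                                          max_iter=500, random_state=0))]:
    orig = cross_val_score(model, X, y, cv=5).mean()
    rot = cross_val_score(model, X_rot, y, cv=5).mean()
    print(f"{name}: original {orig:.3f} | rotated {rot:.3f}")
```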

Meanwhile, let’s look into why rotational invariance matters. According to the authors, taking linear combinations of features (which is what rotation-invariant models like ResNets effectively do) might actually misrepresent features and their relationships.

…there is a natural basis (here, the original basis) which encodes best data-biases, and which can not be recovered by models invariant to rotations which potentially mixes features with very different statistical properties

Based on the performance drops, this is clearly a very important factor that needs to be considered. Going forward, I can see a lot of value in investigating the best data orientations. But I want to learn more about this before making any real comments on this. I’ve spent the last 4 days trying to learn about this, and so far (just like Jon Snow), I know nothing. For now, I’ll end things here.

If you’re looking to get into ML, this article gives you a step-by-step plan to develop proficiency in Machine Learning. It uses FREE resources. Unlike the other boot camps/courses, this plan will help you develop your foundational skills and set yourself up for long-term success in the field.

For Machine Learning, a base in Software Engineering, Math, and Computer Science is crucial. It will help you conceptualize, build, and optimize your ML solutions. My daily newsletter, Technology Interviews Made Simple, covers topics in Algorithm Design, Math, Recent Events in Tech, Software Engineering, and much more to make you a better developer. I am currently running a 20% discount for a WHOLE YEAR, so make sure to check it out.

I created Technology Interviews Made Simple using new techniques discovered through tutoring multiple people into top tech firms. The newsletter is designed to help you succeed, saving you from hours wasted on the Leetcode grind. I have a 100% satisfaction policy, so you can try it out at no risk to you. You can read the FAQs and find out more here

Feel free to reach out if you have any interesting jobs/projects/ideas for me as well. Always happy to hear you out.

Reach out to me

Use the links below to check out my other content, learn more about tutoring, or just to say hi.

Free Weekly Summary of the important updates in Machine Learning (sponsored): https://lnkd.in/gCFTuivn

Check out my other articles on Medium: https://rb.gy/zn1aiu

My YouTube: https://rb.gy/88iwdd

Reach out to me on LinkedIn. Let’s connect: https://rb.gy/m5ok2y

My Instagram: https://rb.gy/gmvuy9

My Twitter: https://twitter.com/Machine01776819

If you’re preparing for coding/technical interviews: https://codinginterviewsmadesimple.substack.com/

Get a free stock on Robinhood: https://join.robinhood.com/fnud75
