TabNet — A Deep Neural Network for Tabular Data

Jihwan
5 min read · Jun 17, 2023


Deep Neural Networks (DNNs) excel at learning from images and text, but they have historically struggled with tabular data. Why is that? Each feature in a table carries its own distinct meaning. This is different from images or text, where DNNs can learn from the relational patterns among neighboring pixels or tokens.

Most practitioners use tree-based models for tabular data because trees handle such features well by drawing explicit decision boundaries. However, as datasets grow larger (especially datasets without labels), tree models start to fall short.

To close this gap, researchers have been trying to make DNNs work better on tabular data. The paper discussed here introduces a new network called TabNet, which uses a DNN architecture that learns in a way similar to tree models. The great thing about TabNet is that it can identify which features are important and learn better representations of the data.

Introduction

TabNet employs a tree-like learning approach in its training process. This not only enhances the model’s performance but also lets us interpret its results by computing feature importance.

Being a DNN-based model, TabNet can perform representation learning on unlabeled data. Moreover, even as the dataset grows large, training proceeds smoothly because the model is optimized on mini-batches with stochastic gradient descent.

In essence, TabNet combines the benefits of tree-based learning and DNN learning. Its performance is impressive enough that it has appeared in top-tier solutions on platforms like Kaggle.

How TabNet Works Like a Decision Tree

1. Learning after Data Dimension Reduction through Feature Selection

Just as a decision tree selects a feature at each node and branches on it, reducing the effective dimensionality as learning progresses, TabNet pre-selects the features to be used in learning. By using only these features, it can effectively find the decision boundary on the data manifold, which is particularly well suited to tabular data with sparse characteristics.
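A key ingredient behind this sparse pre-selection is the sparsemax projection, which TabNet's attentive transformer uses instead of softmax so that irrelevant features receive exactly zero weight. Below is a minimal pure-Python sketch of sparsemax (the input scores are illustrative, and this is a didactic version, not the paper's batched implementation):

```python
def sparsemax(scores):
    """Project a score vector onto the probability simplex.

    Unlike softmax, low-scoring entries come out exactly zero,
    which is what lets TabNet drop features entirely at a step.
    """
    z = sorted(scores, reverse=True)
    # Running sums of the sorted scores.
    cumsums, total = [], 0.0
    for zj in z:
        total += zj
        cumsums.append(total)
    # Largest k whose top-k scores remain in the support.
    k = max(j for j in range(1, len(z) + 1) if 1 + j * z[j - 1] > cumsums[j - 1])
    tau = (cumsums[k - 1] - 1) / k  # threshold subtracted from every score
    return [max(s - tau, 0.0) for s in scores]

mask = sparsemax([3.0, 1.0, 0.2])
# One feature dominates, so the mask collapses onto it: [1.0, 0.0, 0.0]
```

Note how the result is a valid probability distribution over features, but with hard zeros — the "masking" that reduces dimensionality at each decision step.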

2. TabNet’s Encoder Learns through Feedback from Previous Results, Similar to Tree Ensemble’s Weak Learner

In a boosted tree ensemble, each weak learner receives the error of the previous tree, applying more weight to incorrectly predicted samples and thereby refining the learning process. TabNet’s encoder operates similarly, using the output of the previous encoder step as feedback to update the feature mask used in the next step. This structure is, in essence, an ensemble of encoder steps, mirroring the ensemble of trees in a tree-based model.
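This feedback loop can be sketched with the paper's "prior scale" update: a feature that was heavily used at one step has its prior shrunk for later steps, P[i] = P[i-1] · (γ − M[i]), where γ ≥ 1 is a relaxation parameter. A minimal illustration (the mask values here are made up):

```python
GAMMA = 1.3  # relaxation parameter; gamma = 1 would force each feature
             # to be used in at most one decision step

def update_prior(prior, mask, gamma=GAMMA):
    """P[i] = P[i-1] * (gamma - M[i]): down-weight features the
    previous step already spent its attention on."""
    return [p * (gamma - m) for p, m in zip(prior, mask)]

prior = [1.0, 1.0, 1.0]        # step 0: every feature equally available
mask_step1 = [1.0, 0.0, 0.0]   # step 1 attended only to feature 0
prior = update_prior(prior, mask_step1)
# Feature 0's prior drops to 0.3, while unused features rise to 1.3,
# nudging the next step toward features not yet explored.
```

With γ slightly above 1, a feature can still be revisited, but the pressure is toward diverse feature sets across steps.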

Deep dive into TabNet

TabNet enforces sparsity in its feature selection, an inductive bias that is well suited to tabular data. By forming a different set of features at each decision step, it successfully introduces diversity into the learning process.

Figure: feature selection during TabNet training

The selection of this feature set is informed by feedback from previous processing results, thereby prioritizing and focusing on the parts that are deemed most important. This approach allows TabNet to effectively learn and adapt to the complex and diverse patterns inherent in tabular data.

One characteristic of tabular data is its sparsity, which can complicate training: sparse data points can act like outliers and skew the model’s learning. By applying masking to reduce the data’s dimensionality, TabNet effectively addresses these sparsity issues. Only the selected features undergo the subsequent linear transformation, which improves the model’s ability to define the decision boundary and keeps it robust to the inherent complexities of tabular data.

Figure: the overall process of TabNet classification

In TabNet, the encoder employs two transformers and a mask function. The output of the feature transformer is split into two parts, which serve as inputs to the fully connected (FC) decision path and to the attentive transformer, respectively.

The FC layer receives the decision-path halves aggregated across all steps. The attentive transformer, on the other hand, uses its portion of the split output to generate mask values, which are then used to select the most relevant feature sets.
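The split itself is simple: of the feature transformer's output vector, the first portion (often called n_d units) feeds the decision path through a ReLU, and the remainder (n_a units) feeds the attentive transformer. A toy sketch with made-up sizes and values:

```python
def split_step_output(h, n_d):
    """Split one step's feature-transformer output into the decision
    part d (passed through ReLU) and the attention part a."""
    d = [max(x, 0.0) for x in h[:n_d]]  # decision branch -> toward final FC
    a = h[n_d:]                          # attention branch -> next step's mask
    return d, a

h = [0.9, -0.4, 1.1, 0.2, -0.7]  # illustrative output, n_d=3, n_a=2
d, a = split_step_output(h, n_d=3)
# d == [0.9, 0.0, 1.1] (negatives zeroed), a == [0.2, -0.7]
```

The ReLU on the decision branch means a step whose decision units are all negative contributes nothing to the final output, effectively letting the model skip steps.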

By aggregating these mask values, the training process can prioritize certain features, distinguishing them by importance. This mechanism improves both the model’s performance and its interpretability.
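The aggregation that yields global feature importance can be sketched as a weighted sum of the per-step masks, where each step's weight reflects how much that step contributed to the decision (the paper derives these weights from the decision embeddings; the weights and masks below are illustrative):

```python
def aggregate_importance(masks, step_weights):
    """Combine per-step masks into one normalized importance vector.

    masks:        list of per-step mask vectors (one value per feature)
    step_weights: per-step contribution weights (illustrative here)
    """
    n_features = len(masks[0])
    agg = [0.0] * n_features
    for mask, w in zip(masks, step_weights):
        for j, m in enumerate(mask):
            agg[j] += w * m
    total = sum(agg)
    return [v / total for v in agg]  # normalize to sum to 1

masks = [[1.0, 0.0, 0.0],   # step 1 looked only at feature 0
         [0.0, 0.5, 0.5]]   # step 2 split attention over features 1 and 2
imp = aggregate_importance(masks, step_weights=[0.6, 0.4])
# Feature 0 dominates because the more influential step focused on it.
```

This is what makes TabNet's interpretability built-in rather than post hoc: the same masks that drive training double as importance scores.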

In TabNet, the feature transformer is trained with the features that were filtered through masking. As the name ‘transformer’ suggests, the feature transformer’s role is to convert input features into an embedding space.

The embedding spaces constructed in all Encoders are then combined and passed through a final fully connected layer (FC) before carrying out classification.
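This final aggregation can be sketched as summing the ReLU-ed decision embeddings element-wise across steps, then applying one fully connected layer to produce class logits. The weight matrix below is an illustrative stand-in for a trained layer:

```python
def aggregate_and_classify(step_embeddings, weights, bias):
    """Sum per-step decision embeddings, then apply a final FC layer.

    step_embeddings: one embedding vector per decision step
    weights, bias:   stand-ins for the trained final FC parameters
    """
    n = len(step_embeddings[0])
    summed = [sum(e[j] for e in step_embeddings) for j in range(n)]
    # logits_k = sum_j W[k][j] * summed[j] + b[k]
    return [sum(w * s for w, s in zip(row, summed)) + b
            for row, b in zip(weights, bias)]

embs = [[0.5, 0.0], [0.2, 0.3]]   # two decision steps, embedding size 2
W = [[1.0, -1.0], [0.0, 2.0]]     # two output classes
logits = aggregate_and_classify(embs, W, bias=[0.0, 0.1])
# summed ≈ [0.7, 0.3], so logits ≈ [0.4, 0.7]
```

Summing (rather than concatenating) keeps the final layer's size independent of the number of decision steps, so steps behave like additive ensemble members.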

The attentive transformer constructs the mask to be used in the next encoder step. Its attention mechanism assigns weights representing the importance of each part of the input, helping the model focus on the essential parts of the data. In this way, each step selects which features to attend to based on feedback from the embedding produced by the previous step, reconfiguring the mask so that feature selection continues throughout training.

Experiment

TabNet showed superior performance compared to other methods, including tree-based models. Its effectiveness has been widely recognized, as evidenced by its use in numerous Kaggle competitions where it consistently achieved high performance.

Conclusion

The authors introduce a deep neural network algorithm inspired by tree-based learning that aligns well with the unique characteristics of tabular data. By combining the strengths of both tree learning and neural networks, this method allows for the application of diverse learning algorithms, such as representation learning and meta-learning, to tabular data. This integration opens the door to a new era of tabular data learning, offering fresh perspectives and approaches to better understand and leverage this data type.

Reference

https://openreview.net/attachment?id=BylRkAEKDH&name=original_pdf

https://www.kaggle.com/c/osic-pulmonary-fibrosis-progression/discussion/189496

https://www.kaggle.com/code/nyanpn/1st-place-public-2nd-place-solution/notebook#Inference
