Introduction to deep tabular models

Deep learning shows promise for improving tabular model performance

Capital One Tech
6 min read · Jul 27, 2023



Deep learning models for tabular data

Academic and industry researchers and practitioners view deep learning as the dominant methodology for machine learning in a wide range of domains, from computer vision to natural language processing. While deep learning isn’t as prominent in the tabular domain, it has become an area of growing interest in recent years due to its ability to handle large datasets with many features and complex interactions.

Historically, deep learning has not performed well on tabular data, especially when compared to strong baselines such as gradient-boosted tree models and random forests. That performance gap shrinks when a variety of new deep learning techniques are applied.

Because these techniques are still at an early stage and far less explored than in domains such as computer vision, the purpose of this article is to provide an overview of deep tabular models and to explore how a deep learning technique called transfer learning can be used in the tabular domain to improve performance.

What are deep tabular models?

Deep tabular models are a class of neural networks that are designed to handle tabular data. These models consist of several layers of nonlinear transformations that allow them to capture complex patterns and dependencies in the data.

Deep tabular architecture

The architecture of a deep tabular model usually consists of an input layer, several hidden layers, and an output layer. The input layer is responsible for accepting the tabular data, while the hidden layers perform feature extraction and dimensionality reduction. The output layer produces the final predictions based on the learned representations.
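As an illustration, a minimal version of this architecture might look like the following PyTorch sketch; the layer sizes and feature counts are placeholder assumptions, not values from any specific study.

```python
import torch
import torch.nn as nn

class TabularMLP(nn.Module):
    """Minimal deep tabular model: input layer -> hidden layers -> output layer."""

    def __init__(self, num_features: int, num_classes: int, hidden_dim: int = 64):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(num_features, hidden_dim),   # input layer accepts the tabular row
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),     # hidden layer: nonlinear feature extraction
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),    # output layer produces the predictions
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.network(x)

# Example: 20 numeric features, binary classification.
model = TabularMLP(num_features=20, num_classes=2)
logits = model(torch.randn(8, 20))  # batch of 8 rows
```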

TabNet and deep tabular models

One of the most popular deep tabular models is TabNet, available as a built-in algorithm on Google Cloud AI Platform Training. Introduced in 2019, TabNet takes raw tabular data as input without any preprocessing and is trained using gradient descent-based optimization. It uses a machine learning technique called sequential attention to choose which features to reason from at each decision step. Importantly, it can perform on par with, or better than, other tabular learning models.
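For example, using the open-source pytorch-tabnet package (a community implementation, not the Google Cloud built-in algorithm mentioned above), training a TabNet classifier on numeric arrays might look like the sketch below; the data here is random placeholder data.

```python
import numpy as np
from pytorch_tabnet.tab_model import TabNetClassifier

# Placeholder data: 1,000 training rows, 20 numeric features, binary labels.
X_train = np.random.randn(1000, 20)
y_train = np.random.randint(0, 2, size=1000)
X_valid = np.random.randn(200, 20)
y_valid = np.random.randint(0, 2, size=200)

# TabNet consumes the raw feature matrix and is trained with gradient descent;
# sequential attention selects which features to use at each decision step.
clf = TabNetClassifier(n_steps=3)  # n_steps = number of sequential decision steps
clf.fit(X_train, y_train, eval_set=[(X_valid, y_valid)], max_epochs=20)

preds = clf.predict(X_valid)
```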

The advantages of deep learning

Deep learning offers unique advantages in the tabular domain, and motivation for advancing research on deep tabular models continues to grow. Since training these models from scratch can be computationally expensive and time-consuming, equipping practitioners with more ways to apply deep learning techniques can serve to mitigate these issues and help them train tabular models with their data more easily.

Transfer learning for deep tabular models

Transfer learning is a powerful technique in deep learning that enables models to transfer knowledge learned from one task to another. It has been widely used in image and natural language processing tasks and is especially beneficial for learning from data-rich tasks and carrying that knowledge over to data-scarce settings. Because this approach requires differentiable models, a key property of neural networks, it is not readily available to other types of machine learning models, including the tree-based methods that are most common in tabular domains.

Recently, researchers have figured out how to apply it to tabular data, where it has shown significant improvements in predictive performance.

This research study, funded by Capital One, demonstrates that deep tabular models can bridge the performance gap between gradient-boosted decision trees (GBDTs) and neural networks, and it highlights a major advantage of neural models: they learn reusable features and are easily fine-tuned in new domains.

The components of transfer learning for deep tabular models

Transfer learning involves using a model pre-trained on a related task to improve performance on a new task. It is especially useful when you have far more data for one task than for another. Transfer learning can be applied to deep tabular models in several ways, depending on the availability of pre-trained models and the similarity of the source and target tasks.

For example, a pre-trained model on a dataset of customer demographics and purchase history can be used as a feature extractor for a new task of predicting customer churn. The pre-trained model can extract useful features such as age, gender, purchase frequency, and product preferences, which can be used as input to a new model for predicting churn.
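A minimal sketch of this feature-extraction pattern is shown below, assuming a hypothetical encoder pre-trained on demographics and purchase history; the layer sizes, file name, and feature counts are illustrative assumptions.

```python
import torch
import torch.nn as nn

# Hypothetical encoder pre-trained on a large demographics/purchase-history dataset.
pretrained_encoder = nn.Sequential(
    nn.Linear(20, 64), nn.ReLU(),
    nn.Linear(64, 32), nn.ReLU(),
)
# pretrained_encoder.load_state_dict(torch.load("pretrained_encoder.pt"))  # if saved weights exist

# Freeze the encoder so it acts purely as a feature extractor.
for param in pretrained_encoder.parameters():
    param.requires_grad = False

# New head trained on the (smaller) churn dataset.
churn_head = nn.Linear(32, 2)

def predict_churn(rows: torch.Tensor) -> torch.Tensor:
    features = pretrained_encoder(rows)   # reusable learned features
    return churn_head(features)           # churn vs. no-churn logits

logits = predict_churn(torch.randn(8, 20))
```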

Supervised versus self-supervised pre-training

Supervised learning involves training a model using labeled data, where the input and output are both known. The goal is for the model to learn to predict the output for new inputs. The model is trained using a loss function that measures the difference between the predicted output and the true output. Examples of supervised learning tasks include image classification, object detection, and many natural language processing tasks.

Self-supervised training involves training a model to learn a representation of the data from unlabeled data. Instead of relying on manually assigned labels, the model creates its own labels from the data itself. As a result, the model can learn useful features and representations that can be transferred to downstream tasks that have a smaller amount of labeled data. Self-supervised pre-training is widely used in domains such as speech recognition and computer vision.
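In the tabular setting, one common self-supervised objective is to mask some feature values and train the model to reconstruct them, so the "labels" come from the data itself. The sketch below illustrates this idea with an illustrative masking rate and model; it is not the specific pre-training method from the paper.

```python
import torch
import torch.nn as nn

encoder = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 32), nn.ReLU())
decoder = nn.Linear(32, 20)  # reconstructs the original 20 features
optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-3)

unlabeled_rows = torch.randn(256, 20)  # placeholder unlabeled tabular data

for _ in range(10):  # a few pre-training steps
    mask = (torch.rand_like(unlabeled_rows) < 0.3).float()  # hide ~30% of cells
    corrupted = unlabeled_rows * (1 - mask)                  # zero out the masked cells
    reconstruction = decoder(encoder(corrupted))
    # The targets are the original values of the masked cells -- no manual labels needed.
    loss = ((reconstruction - unlabeled_rows) ** 2 * mask).sum() / mask.sum().clamp(min=1)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```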

Fine-tuning

Fine-tuning involves using a pre-trained model as a starting point and training the model on a new task. Fine-tuning can be done by updating some or all of the model’s parameters using the new task’s dataset. For example, a pre-trained model on a dataset of medical images can be fine-tuned for a new task of diagnosing lung cancer. The pre-trained model can be initialized with the learned weights and fine-tuned on a dataset of lung CT scans, where the model learns to identify patterns specific to lung cancer.
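A minimal fine-tuning sketch follows: a pre-trained encoder initializes the model, a new task head is attached, and all parameters are updated on the new task with a small learning rate. The file name, dimensions, and placeholder data are assumptions for illustration.

```python
import torch
import torch.nn as nn

# Encoder architecture must match the pre-trained checkpoint.
encoder = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 32), nn.ReLU())
# encoder.load_state_dict(torch.load("pretrained_encoder.pt"))  # initialize with learned weights

task_head = nn.Linear(32, 2)  # new output layer for the downstream task
model = nn.Sequential(encoder, task_head)

# Small learning rate so fine-tuning adjusts, rather than overwrites, the pre-trained weights.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()

new_task_X = torch.randn(128, 20)          # placeholder downstream dataset
new_task_y = torch.randint(0, 2, (128,))

for _ in range(5):
    loss = loss_fn(model(new_task_X), new_task_y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```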

Multi-task learning

Multi-task learning involves training a model on multiple related tasks simultaneously. This technique can be useful in scenarios where the tasks share some common features or have related objectives. For example, a model can be trained on a dataset of customer purchase history and demographic data to predict customer churn and recommend new products. In this scenario, the model learns to capture the patterns that are common to both tasks, such as customer preferences and purchase behavior, while also learning to perform specific tasks.
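A minimal multi-task sketch with a shared trunk and two task-specific heads (churn classification and scoring candidate products to recommend) is shown below; the architecture, equal loss weighting, and placeholder data are illustrative assumptions.

```python
import torch
import torch.nn as nn

shared_trunk = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 32), nn.ReLU())
churn_head = nn.Linear(32, 2)        # task 1: churn vs. no churn
recommend_head = nn.Linear(32, 10)   # task 2: score 10 candidate products

params = list(shared_trunk.parameters()) + list(churn_head.parameters()) + list(recommend_head.parameters())
optimizer = torch.optim.Adam(params, lr=1e-3)
churn_loss_fn = nn.CrossEntropyLoss()
recommend_loss_fn = nn.CrossEntropyLoss()

X = torch.randn(64, 20)                      # shared customer features
churn_labels = torch.randint(0, 2, (64,))
product_labels = torch.randint(0, 10, (64,))

for _ in range(5):
    features = shared_trunk(X)               # representations shared by both tasks
    loss = (churn_loss_fn(churn_head(features), churn_labels)
            + recommend_loss_fn(recommend_head(features), product_labels))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```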

How to decide which pre-training strategy to use in a specific domain

Whether to use a supervised or self-supervised pre-training strategy depends on a number of factors.

There are times when supervised pre-training outperforms self-supervised pre-training, and vice versa. For example, when the practitioner has a supervised pre-training task with a large enough number of labels, supervised pre-training may result in better performance. Without a large number of labels, however, self-supervised pre-training might outperform the supervised approach.

Transfer learning with deep tabular models

This study, done in partnership with Capital One and researchers at the University of Washington, University of Maryland, and New York University, shows that transfer learning is an effective technique for improving the performance of deep tabular models. By leveraging pre-trained models and transferring knowledge learned from related tasks, practitioners can significantly reduce training time and improve the predictive performance of tabular models.

Bayan Bruss, Sr. Director, Applied ML Research at Capital One, contributed to the academic paper and collaborated with the team to bring this leading machine learning technique to the tabular domain. The paper's acceptance at ICLR 2023, a renowned machine learning research conference, is a significant achievement and indicates the importance of these contributions to the field.

Practitioners can read a summary of transfer learning with deep tabular models and find out why transfer learning stands to challenge the traditionally dominant GBDT models. As transfer learning techniques continue to evolve, they will undoubtedly play a crucial role in advancing the field of deep tabular modeling.

Originally published at https://www.capitalone.com.

Authored by Bayan Bruss, VP, Machine Learning Engineering

Bayan Bruss leads the Applied AI Research team at Capital One. His team aims to accelerate the adoption of academic and industry research in production systems and is currently focused on Graph Machine Learning, Foundation Models, Sequential Models, Machine Learning for Data and Privacy, and Explainable AI. Prior to Capital One, Bayan gained over a decade of experience in academia, startups, and consulting. He has served on the organizing and program committees of several conferences and workshops at ICML, KDD, ICAIF, and NeurIPS. He holds an adjunct position at Georgetown University.

DISCLOSURE STATEMENT: © 2023 Capital One. Opinions are those of the individual author. Unless noted otherwise in this post, Capital One is not affiliated with, nor endorsed by, any of the companies mentioned. All trademarks and other intellectual property used or displayed are property of their respective owners. Capital One is not responsible for the content or privacy policies of any linked third-party sites.
