HyperTab: Hypernetwork Approach for Deep Learning on Small Tabular Datasets

`pip install hypertab` is all you need

Witold Wydmanski
4 min read · Oct 11, 2023

Overview of the article:

  1. Why are tabular data so interesting?
  2. Introduction to hypernetworks
  3. What is HyperTab?
  4. Why use HyperTab?
  5. How to use HyperTab?
  6. How does HyperTab perform?

Why are tabular data so interesting?

Different types of datasets. Source: [parrot], [time series]

There are many different types of data: real-world images, time series, natural language, and, of course, tabular. The popularity of the last one is best reflected by Kaggle statistics: 6,688 of the available datasets are tagged as "tabular", 4,908 datasets carry the tag "image", and 178 are tagged as "text". However, tabular datasets are extremely difficult to work with.

Why is that?

Well, because, unlike real-world images for example, there is no single domain of tabular data. Each table that you come across will probably be drawn from a totally different distribution and governed by totally different laws.

So how can we deal with it?

Using HyperTab!

Introduction to hypernetworks

However, let’s start from the beginning, with an introduction to hypernetworks. A hypernetwork is a neural network that generates the weights of another neural network, called the target network. This is an example of a metalearning framework, which sees the data as split between tasks.

Metalearning framework. Hypernetwork is given task data, and creates a target network which is specialized in solving this task. Source: own creation

What is a task, you might ask?

Let’s imagine that you are given a simple task: distinguishing between cats and dogs. "This task is so simple," says your boss, "how many samples per class could you need? 10? 20? I will give you 20." And so, you are given 20 samples per class. However, both you and I know that 40 samples are not enough to train a regular CNN. And this is where metalearning comes to the rescue.

Of course, you still need more data than 40 samples, so let’s say that you also have access to 20 images of tigers, 20 images of lions, and so on. You don’t want to create a full ImageNet classifier, so you can’t simply train one big CNN on all of this data. Metalearning and hypernetworks, however, are able to utilize this miscellaneous data while keeping the final classifier as simple as it gets.

In this example, our task data will be a set of images of two different classes, be it cats and dogs or tigers and lions. The job of the hypernetwork will be to create a simple classifier that puts each input X into one of the two classes.
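To make this concrete, here is a minimal sketch of the idea in NumPy. Everything here is hypothetical for illustration (the sizes, the single-layer hypernetwork, and the linear target network are not the architecture from any particular paper): a hypernetwork is simply a function that maps a task description to the parameters of a target network.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes: the target network is a tiny linear classifier
# mapping 6 input features to 2 classes (6*2 weights + 2 biases = 14 params).
N_FEATURES, N_CLASSES = 6, 2
TARGET_PARAMS = N_FEATURES * N_CLASSES + N_CLASSES

# The hypernetwork itself: here just one (untrained) linear layer that maps
# a task description to the target network's flattened parameter vector.
W_hyper = rng.normal(scale=0.1, size=(N_FEATURES, TARGET_PARAMS))

def hypernetwork(task_embedding):
    """Generate the target network's parameters from a task description."""
    flat = task_embedding @ W_hyper
    W_target = flat[: N_FEATURES * N_CLASSES].reshape(N_FEATURES, N_CLASSES)
    b_target = flat[N_FEATURES * N_CLASSES :]
    return W_target, b_target

def target_network(x, W_target, b_target):
    """The generated classifier: a plain linear model with given weights."""
    return x @ W_target + b_target

# Two different task descriptions yield two different specialized classifiers.
task_a = np.array([1, 0, 0, 1, 0, 1], dtype=float)
task_b = np.array([0, 1, 1, 0, 1, 0], dtype=float)
W_a, b_a = hypernetwork(task_a)
W_b, b_b = hypernetwork(task_b)

x = rng.normal(size=(1, N_FEATURES))
print(target_network(x, W_a, b_a))
print(target_network(x, W_b, b_b))
```

Note that only the hypernetwork has trainable parameters here; during training, gradients would flow through the generated weights back into `W_hyper`, so one set of hypernetwork parameters serves every task.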

Sounds intuitive? So let’s go into HyperTab.

What is HyperTab?

HyperTab is a hypernetwork-based approach to solving the classification problem of small tabular datasets.

HyperTab uses a hypernetwork to create an ensemble of target networks, where each target network is specialized to process a specific lower-dimensional view of the data. The lower-dimensional views are obtained by randomly selecting a subset of features from the original data, which plays the role of data augmentation.

Source: original paper

So, from the metalearning point of view, we give the hypernetwork task data (a binary mask of features), saying: OK, Mr. Hypernetwork, I want you to create a target network that will specialize in predictions using only features no. 1, 4, and 6.

And the hypernetwork obeys: it creates a separate, specialized target network for each binary mask that we feed it.

Finally, when it’s time to make a prediction, we can ditch the hypernetwork and simply average the outputs of all target networks to create the final prediction.
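The prediction step can be sketched like this. The random masks and the stand-in linear target networks below are hypothetical placeholders (in HyperTab the weights would come from the trained hypernetwork); the point is only the mask-then-average flow:

```python
import numpy as np

rng = np.random.default_rng(42)
N_FEATURES, N_CLASSES, N_MASKS = 8, 2, 5

# Each mask selects a random subset of features. The hypernetwork would
# generate one specialized target network per mask; here we stand in
# random linear classifiers purely for illustration.
masks = (rng.random((N_MASKS, N_FEATURES)) < 0.5).astype(float)
target_weights = [rng.normal(size=(N_FEATURES, N_CLASSES)) for _ in range(N_MASKS)]

def predict(x):
    """Ensemble prediction: each target network sees only its mask's
    features, and the final output is the average over the ensemble."""
    outputs = [(x * m) @ W for m, W in zip(masks, target_weights)]
    return np.mean(outputs, axis=0)

x = rng.normal(size=(1, N_FEATURES))
print(predict(x).shape)  # one averaged score per class
```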

Why use HyperTab?

HyperTab has several advantages over existing methods for tabular data analysis. First, HyperTab can effectively handle small datasets, where the number of samples is much smaller than the number of features. This is because HyperTab can generate multiple views of each sample by feature subsetting, which increases the diversity and quantity of the training data, effectively implementing a non-domain-specific tabular data augmentation.
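A quick back-of-the-envelope sketch of why feature subsetting acts as augmentation: every (sample, mask) pair is a distinct training example for the hypernetwork, so the effective dataset size is multiplied by the number of masks. The sizes below are made up for illustration:

```python
import numpy as np

n_samples, n_features, n_masks = 50, 20, 100
X = np.random.default_rng(0).normal(size=(n_samples, n_features))
masks = np.random.default_rng(1).random((n_masks, n_features)) < 0.5

# Each sample paired with each mask gives a distinct training example,
# so 50 samples become 50 * 100 = 5000 effective examples.
augmented = [(x * m, m) for x in X for m in masks]
print(len(augmented))  # 5000
```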

Second, HyperTab does not depend on any specific feature. Imagine a scenario in which you’re running inference on the readings of several sensors, and one of the sensors breaks. By simply removing the target networks that depend on that sensor, HyperTab can keep predicting without any further tuning.
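Selecting the still-usable part of the ensemble is just a filter over the binary masks. A sketch, with hypothetical masks and a made-up broken feature index:

```python
import numpy as np

rng = np.random.default_rng(7)
n_features, n_masks = 6, 10
masks = rng.random((n_masks, n_features)) < 0.5

broken_sensor = 2  # hypothetical: feature 2's sensor went offline

# Keep only the target networks whose mask never reads the broken feature;
# the rest of the ensemble keeps predicting with no retraining.
usable = masks[~masks[:, broken_sensor]]
print(len(usable), "of", n_masks, "target networks still usable")
```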

How to use HyperTab?

```python
from hypertab import HyperTabClassifier

model = HyperTabClassifier()
model.fit(X_train, y_train)
predictions = model.predict(X_test)
```

And that’s it.

How does HyperTab perform?

HyperTab has been evaluated on more than 40 tabular datasets from different domains and compared with state-of-the-art methods, such as Random Forests, XGBoost, Fully Connected Networks with Dropout, and Neural Oblivious Decision Ensembles (NODE). The results show that HyperTab consistently outperforms other methods on small datasets (with statistically significant differences) and scores comparable to them on larger datasets.

Results of comparing HyperTab to other algorithms on small (<1k samples) and large tabular datasets. Source: original paper

If you are interested in learning more about HyperTab, you can read our paper here or check out our Python package here. HyperTab is an innovative and effective technique for building deep learning models for small tabular datasets that combines feature subsetting augmentations with neural network ensembles using hypernetworks. We hope you find it useful and inspiring for your own projects!


Witold Wydmanski

PhD candidate in ML at GMUM UJ, bioinformatician at MCB UJ