
March 22, 2026



A modular deep learning framework for building neural network models on heterogeneous tabular data.



Documentation | Paper

PyTorch Frame is a deep learning extension for PyTorch, designed for heterogeneous tabular data whose columns can have different types, including numerical, categorical, time, text, and images. It offers a modular framework for implementing existing and future methods. The library includes implementations of state-of-the-art models, user-friendly mini-batch loaders, benchmark datasets, and interfaces for integrating custom data.

PyTorch Frame democratizes deep learning research for tabular data, catering to novices and experts alike. Our goals are:

  1. Facilitate Deep Learning for Tabular Data: Historically, tree-based models (e.g., GBDTs) have excelled at tabular learning but have notable limitations, such as difficulty integrating with downstream models and difficulty handling complex column types like text, sequences, and embeddings. Deep tabular models are a promising way to address these limitations. We aim to facilitate deep learning research on tabular data by modularizing its implementation and supporting diverse column types.

  2. Integrate with Diverse Model Architectures, such as Large Language Models: PyTorch Frame supports integration with a variety of architectures, including LLMs. With any downloaded model or embedding API endpoint, you can encode your text data as embeddings and train deep learning models on them alongside other complex semantic types. Supported providers include, but are not limited to:

  • OpenAI: OpenAI Embedding Code Example
  • Cohere: Cohere Embed v3 Code Example
  • Hugging Face: Hugging Face Code Example
  • Voyage AI: Voyage AI Code Example
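Concretely, a text encoder in this setting is a function that maps a batch of strings to fixed-size vectors. The sketch below is a toy, dependency-free stand-in for such an encoder: `HashingTextEmbedder` is a hypothetical name, and it uses a hashed bag-of-words rather than a real model or API. In PyTorch Frame the callable is typically wrapped in `TextEmbedderConfig` (from `torch_frame.config.text_embedder`) and is expected to return a `torch.Tensor`; see the provider code examples listed above for the actual wiring.

```python
import math
import zlib


class HashingTextEmbedder:
    """Hypothetical toy embedder: list[str] -> list of fixed-size vectors.

    A real setup would call OpenAI, Cohere, Hugging Face, or Voyage AI here
    and return a torch.Tensor of shape [len(sentences), dim].
    """

    def __init__(self, dim: int = 8) -> None:
        self.dim = dim

    def _embed_one(self, text: str) -> list[float]:
        vec = [0.0] * self.dim
        for token in text.lower().split():
            # crc32 is deterministic across runs, unlike Python's built-in hash().
            bucket = zlib.crc32(token.encode("utf-8")) % self.dim
            vec[bucket] += 1.0
        norm = math.sqrt(sum(v * v for v in vec))
        # L2-normalize so texts of different lengths are comparable.
        return [v / norm for v in vec] if norm > 0 else vec

    def __call__(self, sentences: list[str]) -> list[list[float]]:
        return [self._embed_one(s) for s in sentences]


embedder = HashingTextEmbedder(dim=8)
emb = embedder(["Crisp and fruity", "Oaky, full-bodied red", ""])
print(len(emb), len(emb[0]))  # 3 8
```

Swapping the toy embedder for a real model changes only the body of `__call__`; the batch-of-strings-in, batch-of-vectors-out contract stays the same.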

Library Highlights

PyTorch Frame builds directly upon PyTorch, ensuring a smooth transition for existing PyTorch users. Key features include:

  • Diverse column types: PyTorch Frame supports learning across various column types: numerical, categorical, multicategorical, text_embedded, text_tokenized, timestamp, image_embedded, and embedding. See the tutorial for details.
  • Modular model design: Enables modular deep learning model implementations, promoting reusability, clean code, and flexible experimentation. Further details are in the Architecture Overview.
  • Models: Implements many state-of-the-art deep tabular models as well as strong GBDTs (XGBoost, CatBoost, and LightGBM) with hyperparameter tuning.
  • Datasets: Comes with a collection of readily usable tabular datasets and also supports custom datasets for your own problems. We benchmark deep tabular models against GBDTs.
  • PyTorch integration: Integrates effortlessly with other PyTorch libraries, enabling end-to-end training of PyTorch Frame models together with downstream PyTorch models. For example, by integrating with PyG, a PyTorch library for GNNs, we can perform deep learning over relational databases. Learn more in RelBench and the example code.
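Each raw column must be assigned one of these column types (stypes) before materialization; in PyTorch Frame this is a `col_to_stype` mapping passed to the dataset, with values such as `stype.numerical` or `stype.categorical`. As a rough illustration, the sketch below guesses a stype name from a column's raw values. `guess_stype` and its word-count threshold are hypothetical heuristics, not the library's inference logic, and the returned strings only mirror the stype names listed above.

```python
from datetime import datetime


def _is_timestamp(value: str) -> bool:
    try:
        datetime.fromisoformat(value)
        return True
    except ValueError:
        return False


def guess_stype(values: list) -> str:
    """Hypothetical heuristic: map one raw column to a stype name."""
    vals = [v for v in values if v is not None]  # ignore missing entries
    if not vals:
        raise ValueError("empty column; set its stype by hand")
    if all(isinstance(v, bool) for v in vals):
        return "categorical"
    if all(isinstance(v, (int, float)) and not isinstance(v, bool) for v in vals):
        return "numerical"
    if all(isinstance(v, list) for v in vals):
        return "multicategorical"  # each cell holds several categories
    if all(isinstance(v, str) for v in vals):
        if all(_is_timestamp(v) for v in vals):
            return "timestamp"
        # Long free-form strings are treated as text; short ones as categories.
        avg_words = sum(len(v.split()) for v in vals) / len(vals)
        return "text_embedded" if avg_words > 3 else "categorical"
    raise ValueError("mixed column; set its stype by hand")


columns = {
    "age": [39, 50, None, 38],
    "workclass": ["State-gov", "Private", "Private", None],
    "is_member": [True, False, True, True],
    "tags": [["a"], ["a", "b"], [], ["c"]],
    "signup": ["2024-01-03", "2023-11-20", "2024-02-14", "2024-03-01"],
    "review": ["Bright acidity with notes of citrus",
               "Dense and tannic, needs a few years"],
}
col_to_stype = {name: guess_stype(vals) for name, vals in columns.items()}
print(col_to_stype)
```

Ambiguous cases, such as integer-coded categories, are better assigned by hand than by any heuristic.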

Architecture Overview

Models in PyTorch Frame follow a modular design of FeatureEncoder, TableConv, and Decoder, as shown in the figure below:

This modular setup lets users experiment with a wide range of architectures:

  • Materialization converts the raw pandas DataFrame into a TensorFrame that is amenable to PyTorch-based training and modeling.
  • FeatureEncoder encodes the TensorFrame into hidden column embeddings of size [batch_size, num_cols, channels].
  • TableConv models column-wise interactions over the hidden embeddings.
  • Decoder generates an embedding or prediction for each row.
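The shape contract of the three stages can be traced with a toy, dependency-free sketch. The functions below are hypothetical stand-ins operating on nested Python lists, not PyTorch Frame modules: they assume materialization has already produced purely numerical columns, and they use random, untrained weights.

```python
import random


def feature_encoder(rows, channels, rng):
    """Stand-in FeatureEncoder: [batch_size, num_cols] -> [batch_size, num_cols, channels].

    Each column gets its own random (weight, bias) per channel, loosely like a
    per-column linear encoder for numerical features.
    """
    num_cols = len(rows[0])
    params = [
        [(rng.uniform(-1, 1), rng.uniform(-1, 1)) for _ in range(channels)]
        for _ in range(num_cols)
    ]
    return [
        [[w * value + b for w, b in params[col]] for col, value in enumerate(row)]
        for row in rows
    ]


def table_conv(x):
    """Stand-in TableConv: mixes information across columns, shape unchanged.

    Each column embedding is averaged with the mean embedding of its row.
    """
    out = []
    for row in x:
        channels = len(row[0])
        mean = [sum(col[k] for col in row) / len(row) for k in range(channels)]
        out.append([[0.5 * (col[k] + mean[k]) for k in range(channels)] for col in row])
    return out


def decoder(x, out_channels, rng):
    """Stand-in Decoder: mean-pool over columns, then a linear map to out_channels."""
    channels = len(x[0][0])
    weight = [[rng.uniform(-1, 1) for _ in range(channels)] for _ in range(out_channels)]
    pooled = [[sum(col[k] for col in row) / len(row) for k in range(channels)] for row in x]
    return [[sum(w_k * p_k for w_k, p_k in zip(w, p)) for w in weight] for p in pooled]


def shape(t):
    """Shape of a nested list, e.g. [batch_size, num_cols, channels]."""
    dims = []
    while isinstance(t, list):
        dims.append(len(t))
        t = t[0]
    return dims


rng = random.Random(0)
batch = [[0.1, 2.0, -1.0], [0.5, 1.0, 3.0]]  # 2 rows, 3 already-numeric columns
h = feature_encoder(batch, channels=4, rng=rng)
print(shape(h))        # [2, 3, 4]
encoded = h
for _ in range(2):     # two stacked TableConv layers
    h = table_conv(h)
print(shape(h))        # [2, 3, 4]
out = decoder(h, out_channels=2, rng=rng)
print(shape(out))      # [2, 2]
```

Because each TableConv preserves the [batch_size, num_cols, channels] shape, layers can be stacked freely; only the decoder changes the shape.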

Quick Tour

In this quick tour, we showcase the ease of creating and training a deep tabular model with only a few lines of code.

Build and train your own deep tabular model

As an example, we implement a simple ExampleTransformer following the modular architecture of PyTorch Frame. In the example below:

  • self.encoder maps an input TensorFrame to an embedding of size [batch_size, num_cols, channels].
  • self.convs iteratively transforms the embedding of size [batch_size, num_cols, channels] into an embedding of the same size.
  • self.decoder pools the embedding of size [batch_size, num_cols, channels] into [batch_size, out_channels].
from torch import Tensor
from torch.nn import Linear, Module, ModuleList

from torch_frame import TensorFrame, stype
from torch_frame.nn.conv import TabTransformerConv
from torch_frame.nn.encoder import (
    EmbeddingEncoder,
    LinearEncoder,
    StypeWiseFeatureEncoder,
)

class ExampleTransformer(Module):
    def __init__(
        self,
        channels, out_channels, num_layers, num_heads,
        col_stats, col_names_dict,
    ):
        super().__init__()
        self.encoder = StypeWiseFeatureEncoder(
            out_channels=channels,
            col_stats=col_stats,
            col_names_dict=col_names_dict,
            stype_encoder_dict={
                stype.categorical: EmbeddingEncoder(),
                stype.numerical: LinearEncoder()
            },
        )
        self.convs = ModuleList([
            TabTransformerConv(
                channels=channels,
                num_heads=num_heads,
            ) for _ in range(num_layers)
        ])
        self.decoder = Linear(channels, out_channels)

    def forward(self, tf: TensorFrame) -> Tensor:
        x, _ = self.encoder(tf)
        for conv in self.convs:
            x = conv(x)
        out = self.decoder(x.mean(dim=1))
        return out

To prepare the data, we can quickly instantiate a pre-defined dataset and create a PyTorch-compatible data loader as follows:

from torch_frame.datasets import Yandex
from torch_frame.data import DataLoader

dataset = Yandex(root='/tmp/adult', name='adult')
dataset.materialize()
train_dataset = dataset[:0.8]
train_loader = DataLoader(train_dataset.tensor_frame, batch_size=128,
                          shuffle=True)
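Conceptually, the split and the shuffled loader reduce to index bookkeeping. The sketch below assumes that `dataset[:0.8]` keeps the first 80% of rows in stored order and that `shuffle=True` draws a fresh permutation of the training rows each epoch; `split_rows`, `minibatches`, and the 1,000-row size are hypothetical, and the real `DataLoader` indexes a `TensorFrame` rather than Python lists.

```python
import random


def split_rows(num_rows, frac):
    """Hypothetical mirror of dataset[:frac] / dataset[frac:]: first rows vs. the rest."""
    cut = int(frac * num_rows)
    return list(range(cut)), list(range(cut, num_rows))


def minibatches(indices, batch_size, shuffle, rng):
    """Yield row-index batches for one epoch, as a shuffling loader would."""
    order = list(indices)
    if shuffle:
        rng.shuffle(order)  # new permutation every time this is called
    for start in range(0, len(order), batch_size):
        yield order[start:start + batch_size]  # the last batch may be smaller


train_idx, test_idx = split_rows(num_rows=1000, frac=0.8)
rng = random.Random(0)
batches = list(minibatches(train_idx, batch_size=128, shuffle=True, rng=rng))
print(len(train_idx), len(test_idx), [len(b) for b in batches])
# 800 200 [128, 128, 128, 128, 128, 128, 32]
```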

Then, we just follow the standard PyTorch training procedure to optimize the model parameters. That's it!

import torch
import torch.nn.functional as F

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ExampleTransformer(
    channels=32,
    out_channels=dataset.num_classes,
    num_layers=2,
    num_heads=8,
    col_stats=train_dataset.col_stats,
    col_names_dict=train_dataset.tensor_frame.col_names_dict,
).to(device)

optimizer = torch.optim.Adam(model.parameters())

model.train()
for epoch in range(50):
    for tf in train_loader:
        tf = tf.to(device)
        pred = model(tf)  # call the module, not .forward(), so hooks run
        loss = F.cross_entropy(pred, tf.y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()  # apply the gradient update
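For reference, `F.cross_entropy` takes raw logits of shape [batch_size, num_classes] plus integer class targets (here `tf.y`) and, by default, averages the per-row negative log-likelihood of the target class under a softmax. The dependency-free sketch below reproduces that computation, together with the accuracy one would typically report at evaluation time; the helper names are illustrative and not part of PyTorch or PyTorch Frame.

```python
import math


def cross_entropy(logits, targets):
    """Mean over rows of -log softmax(logits)[target], like F.cross_entropy's default."""
    total = 0.0
    for row, target in zip(logits, targets):
        m = max(row)  # subtract the row max so exp() cannot overflow
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_sum_exp - row[target]
    return total / len(targets)


def accuracy(logits, targets):
    """Fraction of rows whose highest-scoring class equals the target (ties pick the first)."""
    hits = sum(
        max(range(len(row)), key=row.__getitem__) == target
        for row, target in zip(logits, targets)
    )
    return hits / len(targets)


logits = [[0.0, 0.0, 0.0], [3.0, 1.0, -1.0]]  # [batch_size, num_classes]
targets = [2, 0]
print(cross_entropy(logits, targets), accuracy(logits, targets))
```

Uniform logits over k classes give a loss of log(k), which is a handy sanity check for an untrained model.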

Implemented Deep Tabular Models

We list currently supported deep tabular models:

In addition, we provide XGBoost, CatBoost, and LightGBM examples with hyperparameter tuning using Optuna, for users who would like to compare their models' performance against GBDTs.
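The structure of such a tuning loop is simple: sample a configuration, train, score on validation data, keep the best. The sketch below swaps Optuna for plain random search to stay dependency-free; `SEARCH_SPACE`, its ranges, and `toy_objective` are hypothetical placeholders for a real GBDT training run. Optuna drives the same loop through `optuna.create_study(direction="minimize")` and `study.optimize(objective, n_trials=...)`, but chooses new trials adaptively (TPE by default) rather than uniformly at random.

```python
import math
import random

# Hypothetical search space for a GBDT; names and ranges are illustrative only.
SEARCH_SPACE = {
    "learning_rate": (1e-3, 0.3),  # sampled log-uniformly
    "max_depth": (3, 10),          # sampled uniformly as an integer, inclusive
}


def sample_params(rng):
    lo, hi = SEARCH_SPACE["learning_rate"]
    return {
        "learning_rate": math.exp(rng.uniform(math.log(lo), math.log(hi))),
        "max_depth": rng.randint(*SEARCH_SPACE["max_depth"]),
    }


def toy_objective(params):
    """Hypothetical validation loss; a real objective would train a GBDT and score it."""
    lr_term = (math.log10(params["learning_rate"]) + 1.5) ** 2
    depth_term = 0.01 * (params["max_depth"] - 6) ** 2
    return lr_term + depth_term


def random_search(objective, n_trials, seed=0):
    """Evaluate n_trials random configurations; return (best_value, best_params, trials)."""
    rng = random.Random(seed)
    trials = []
    for _ in range(n_trials):
        params = sample_params(rng)
        trials.append((objective(params), params))
    best_value, best_params = min(trials, key=lambda t: t[0])
    return best_value, best_params, trials


best_value, best_params, trials = random_search(toy_objective, n_trials=30)
print(best_value, best_params)
```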

Benchmark

We benchmark recent tabular deep learning models against GBDTs over diverse public datasets with different sizes and task types.

The following table shows the performance of various models on small regression datasets: rows are models and columns are dataset indices (13 datasets in total). OOM means the model ran out of memory. For more results on classification and larger datasets, please check the benchmark documentation.

| Model Name | dataset_0 | dataset_1 | dataset_2 | dataset_3 | dataset_4 | dataset_5 | dataset_6 | dataset_7 | dataset_8 | dataset_9 | dataset_10 | dataset_11 | dataset_12 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| XGBoost | 0.250±0.000 | 0.038±0.000 | 0.187±0.000 | 0.475±0.000 | 0.328±0.000 | 0.401±0.000 | 0.249±0.000 | 0.363±0.000 | 0.904±0.000 | 0.056±0.000 | 0.820±0.000 | 0.857±0.000 | 0.418±0.000 |
| CatBoost | 0.265±0.000 | 0.062±0.000 | 0.128±0.000 | 0.336±0.000 | 0.346±0.000 | 0.443±0.000 | 0.375±0.000 | 0.273±0.000 | 0.881±0.000 | 0.040±0.000 | 0.756±0.000 | 0.876±0.000 | 0.439±0.000 |
| LightGBM | 0.253±0.000 | 0.054±0.000 | 0.112±0.000 | 0.302±0.000 | 0.325±0.000 | 0.384±0.000 | 0.295±0.000 | 0.272±0.000 | 0.877±0.000 | 0.011±0.000 | 0.702±0.000 | 0.863±0.000 | 0.395±0.000 |
| Trompt | 0.261±0.003 | 0.015±0.005 | 0.118±0.001 | 0.262±0.001 | 0.323±0.001 | 0.418±0.003 | 0.329±0.009 | 0.312±0.002 | OOM | 0.008±0.001 | 0.779±0.006 | 0.874±0.004 | 0.424±0.005 |
| ResNet | 0.288±0.006 | 0.018±0.003 | 0.124±0.001 | 0.268±0.001 | 0.335±0.001 | 0.434±0.004 | 0.325±0.012 | 0.324±0.004 | 0.895±0.005 | 0.036±0.002 | 0.794±0.006 | 0.875±0.004 | 0.468±0.004 |
| FTTransformerBucket | 0.325±0.008 | 0.096±0.005 | 0.360±0.354 | 0.284±0.005 | 0.342±0.004 | 0.441±0.003 | 0.345±0.007 | 0.339±0.003 | OOM | 0.105±0.011 | 0.807±0.010 | 0.885±0.008 | 0.468±0.006 |
| ExcelFormer | 0.262±0.004 | 0.099±0.003 | 0.128±0.000 | 0.264±0.003 | 0.331±0.003 | 0.411±0.005 | 0.298±0.012 | 0.308±0.007 | OOM | 0.011±0.001 | 0.785±0.011 | 0.890±0.003 | 0.431±0.006 |
| FTTransformer | 0.335±0.010 | 0.161±0.022 | 0.140±0.002 | 0.277±0.004 | 0.335±0.003 | 0.445±0.003 | 0.361±0.018 | 0.345±0.005 | OOM | 0.106±0.012 | 0.826±0.005 | 0.896±0.007 | 0.461±0.003 |
| TabNet | 0.279±0.003 | 0.224±0.016 | 0.141±0.010 | 0.275±0.002 | 0.348±0.003 | 0.451±0.007 | 0.355±0.030 | 0.332±0.004 | 0.992±0.182 | 0.015±0.002 | 0.805±0.014 | 0.885±0.013 | 0.544±0.011 |
| TabTransformer | 0.624±0.003 | 0.229±0.003 | 0.369±0.005 | 0.340±0.004 | 0.388±0.002 | 0.539±0.003 | 0.619±0.005 | 0.351±0.001 | 0.893±0.005 | 0.431±0.001 | 0.819±0.002 | 0.886±0.005 | 0.545±0.004 |

We see that some recent deep tabular models achieve performance competitive with strong GBDTs, despite being 5 to 100 times slower to train. Making deep tabular models even more performant with less compute is a fruitful direction for future research.
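To compare models per dataset programmatically, the mean±std cells can be parsed and ranked. The sketch below copies three columns of the regression table (dataset_0, dataset_3, dataset_8) and assumes lower is better, as for an error metric such as RMSE; OOM cells are skipped. It is an illustration only and ignores the reported standard deviations, so near-ties (e.g. Trompt vs. ExcelFormer on dataset_3) should not be over-read.

```python
DATASETS = ["dataset_0", "dataset_3", "dataset_8"]
# Cells copied from the regression benchmark table.
ROWS = {
    "XGBoost": ["0.250±0.000", "0.475±0.000", "0.904±0.000"],
    "CatBoost": ["0.265±0.000", "0.336±0.000", "0.881±0.000"],
    "LightGBM": ["0.253±0.000", "0.302±0.000", "0.877±0.000"],
    "Trompt": ["0.261±0.003", "0.262±0.001", "OOM"],
    "ResNet": ["0.288±0.006", "0.268±0.001", "0.895±0.005"],
    "FTTransformerBucket": ["0.325±0.008", "0.284±0.005", "OOM"],
    "ExcelFormer": ["0.262±0.004", "0.264±0.003", "OOM"],
    "FTTransformer": ["0.335±0.010", "0.277±0.004", "OOM"],
    "TabNet": ["0.279±0.003", "0.275±0.002", "0.992±0.182"],
    "TabTransformer": ["0.624±0.003", "0.340±0.004", "0.893±0.005"],
}


def parse_cell(cell):
    """'0.250±0.000' -> (0.25, 0.0); 'OOM' -> None."""
    if cell == "OOM":
        return None
    mean, std = cell.split("±")
    return float(mean), float(std)


def best_per_dataset(rows, datasets):
    """Model with the lowest mean score on each dataset (assumes lower is better)."""
    best = {}
    for j, name in enumerate(datasets):
        candidates = []
        for model, cells in rows.items():
            parsed = parse_cell(cells[j])
            if parsed is not None:  # skip out-of-memory runs
                candidates.append((parsed[0], model))
        best[name] = min(candidates)[1]
    return best


best = best_per_dataset(ROWS, DATASETS)
print(best)
# {'dataset_0': 'XGBoost', 'dataset_3': 'Trompt', 'dataset_8': 'LightGBM'}
```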

We also benchmark different text encoders on a real-world tabular dataset (Wine Reviews) with one text column. The following table shows the performance:

| Test Acc | Method | Model Name | Source |
|---|---|---|---|
| 0.7926 | Pre-trained | sentence-transformers/all-distilroberta-v1 (125M params) | Hugging Face |
| 0.7998 | Pre-trained | embed-english-v3.0 (dimension size: 1024) | Cohere |
| 0.8102 | Pre-trained | text-embedding-ada-002 (dimension size: 1536) | OpenAI |
| 0.8147 | Pre-trained | voyage-01 (dimension size: 1024) | Voyage AI |
| 0.8203 | Pre-trained | intfloat/e5-mistral-7b-instruct (7B params) | Hugging Face |
| 0.8230 | LoRA Finetune | DistilBERT (66M params) | Hugging Face |

The benchmark script for the Hugging Face text encoders is in this file, and the script for the remaining text encoders is in this file.

Installation

PyTorch Frame is available for Python 3.10 to Python 3.14.

pip install pytorch-frame

See the installation guide for other options.

Cite

If you use PyTorch Frame in your work, please cite our paper (Bibtex below).

@article{hu2024pytorch,
  title={PyTorch Frame: A Modular Framework for Multi-Modal Tabular Learning},
  author={Hu, Weihua and Yuan, Yiwen and Zhang, Zecheng and Nitta, Akihiro and Cao, Kaidi and Kocijan, Vid and Leskovec, Jure and Fey, Matthias},
  journal={arXiv preprint arXiv:2404.00776},
  year={2024}
}