Learning Robotic Video Dynamics with Heterogeneous Masked Autoregression

February 7, 2025 · View on GitHub

HF Models PyTorch Python

Arxiv Website Demo


Getting Started

We require Python 3.10 or later. This code was tested with Python 3.10.12.

# Install dependencies and download data
./build.sh

# Source the Python environment
source venv/bin/activate

Demo

  1. Run the interactive demo: python -m sim.app
  2. Select a prompt initial image from the gallery
  3. Interact with arrow keys

Dataset Encoding/Tokenization

Detailed Commands The one-line version of encoding one dataset is as the following example: ```shell python -m datasets.encode_openx_dataset --episode_cnt 1000000 --dataset_name kaist_nonprehensile_converted_externally_to_rlds --data_split train --root_dir data ```

The dataset can be partitioned into multiple shards, and then one GPU can process each shard. Shards are useful not only for parallelization but also for saving progress in case a process is interrupted.

Example encoding droid with 2 GPUs:

# Process 1
set -e
for ((i = 0; i < 64; i += 2)); do
    CUDA_VISIBLE_DEVICES=0 python -m datasets.encode_openx_dataset --dataset_name droid --data_split train --num_shards 64 --curr_shard_rank $i --root_dir sharded_data
done

# Process 2
set -e
for ((i = 1; i < 64; i += 2)); do
    CUDA_VISIBLE_DEVICES=1 python -m datasets.encode_openx_dataset --dataset_name droid --data_split train --num_shards 64 --curr_shard_rank $i --root_dir sharded_data
done

Then, merge the shards into one dataset. Merging does not require all shards to be generated, missing shards will be skipped.

Example merge:

# Modify SHARD_DATA_FORMAT in dataset/merge_shards.py, not CLI arg b/c it's a format str
python datasets/merge_shards.py --out_data_dir merged_data/droid --num_shards 64

Non-OpenX datasets require using datasets.encode_openx_dataset instead of datasets.encode_openx_dataset, and currently the only such supported dataset is egoexo4d. For every external dataset, the code for that dataset must be modified to allow iterating over only a subset (i.e. shard) of the dataset.

Example non-OpenX dataset:

python -m datasets.encode_extern_dataset --dataset_name egoexo4d --data_split train --num_shards 100 --curr_shard_rank 0 --root_dir sharded_data

To train and evaluate the soft tokens, we need to follow the same script but save dataset using non-VQ encoders. Similarly, to evaluate against the raw images, we need to use the script and save dataset without encoders --encoder_type temporalvae --encoder_name_or_path 'stabilityai/stable-video-diffusion-img2vid'.

Pre-Training Scripts (Single Dataset)

# Single Dataset Training, Generation, and Evaluation
python -m hma.train_multi --output_dir data/model

python hma/generate.py  --checkpoint_dir  data/model/step_100/  --val_data_dir data/kaist_nonprehensile_converted_externally_to_rlds_magvit_max1000000_val

python hma/visualize.py   --token_dir data/genie_generated

python hma/evaluate.py  --checkpoint_dir  data/model/step_100/  --val_data_dir data/kaist_nonprehensile_converted_externally_to_rlds_magvit_max1000000_val --use_tokeniz
ed_images

Pre-Training Scripts (Multiple Datasets)


# Debug Run
bash experiments/scripts/run_debug.sh

# VQ tokens Model
bash experiments/scripts/discrete_model/run_40datasets_waction.sh

# Soft tokens Model
bash experiments/scripts/continuous_model/run_30datasets_mar_waction.sh

Post-Training Scripts

  1. Finetuning on language table dataset bash experiments/scripts/posttraining_scripts/run_langtable_finetuning.sh

Checkpoints

You can find pretrained HMA checkpoints here. At the moment we provide the following model versions:

ModelSize
HMA-MagVit362M Params
HMA-MAR1B Params

Evaluation

Example on discrete model evaluation:

accelerate launch hma/evaluate.py \
        --checkpoint_dir "data/${RUN_NAME}/final_checkpt" \
        --val_data_dir "data/${dataset}_magvit_traj1000000_val" \
        --wandb_run_name "${RUN_NAME}"'

bash experiments/scripts/eval_action_scripts/run_evaluation_discrete.sh $MODEL $DATASET

Example on continuous model evaluation:

accelerate launch hma/evaluate_feature.py \
        --checkpoint_dir "data/${RUN_NAME}/final_checkpt" \
        --val_data_dir "data/${dataset}_magvit_traj1000000_val" \
        --wandb_run_name "${RUN_NAME}"

bash experiments/scripts/eval_action_scripts/run_evaluation_continuous.sh $MODEL $DATASET

Note

  1. Training, evaluation, and visualization are stored in different wandb projects.
  2. Code quality: tired grad student.

File Structures

├── ...
├── HMA
|   |── data 			# cached token datasets and model checkpoints
|   |── hma 			# main modeling code
|   |   |── model       # model related scripts
|   |   |── evaluate.py   # evaluate a trained model
|   |   |── generate.py    # generate tokens from trained model
|   |   |── train_multi.py # train on multiple datasets jointly
|   |   |── visualize.py # visualize generated tokens
|   |── sim 			# simulation related codebase
|   |── experiments
|   |   |── datasplit # dataset split
|   |   |── scripts # ablation and training scripts.
|   |── external 		# common utility
└── ...

🕹️ Citation

If you find HMA useful in your research, please consider citing:

@inproceedings{wang2025hma,
author    = {Lirui Wang, Kevin Zhao, Chaoqi Liu, Xinlei Chen},
title     = {Learning Robotic Video Dynamics with Heterogeneous Masked Autoregression},
booktitle = {Arxiv},
year      = {2025}
}

Acknowledgement

  1. 1xGPT
  2. MAR
  3. HPT

Contact

If you have any questions, feel free to contact me through email (liruiw@mit.edu). Enjoy!