Generative Predictive Control

February 20, 2025 ยท View on GitHub

This repository contains code for the paper "Generative Predictive Control: Flow Matching Policies for Dynamic and Difficult-to-Demonstrate Tasks" by Vince Kurtz and Joel Burdick. Video summary.

This includes code for training and testing flow-matching policies on each of the robot systems shown below:

Generative Predictive Control (GPC) is a supervised learning framework for training flow-matching policies on tasks that are difficult to demonstrate but easy to simulate. GPC alternates between generating training data with sampling-based predictive control, fitting a generative model to the data, and using the generative model to improve the sampling distribution.

Install (Conda)

Clone and create the conda env (first time only):

git clone https://github.com/vincekurtz/gpc.git
cd gpc
conda env create -f environment.yml

Enter the conda env:

conda activate gpc

Install the package and dependencies:

pip install -e .

Examples

Various examples can be found in the examples directory. For example, to train a cart-pole swingup policy using GPC, run:

python examples/cart_pole.py train

This will train a flow-matching policy and save it to /tmp/cart_pole_policy.pkl. To run an interactive simulation with the trained policy, run

python examples/cart_pole.py test

To see other command-line options, run

python examples/cart_pole.py --help

Using a Different Robot Model

To try GPC on your own robot or task, you will need to:

  1. Define a Hydrax task that encodes the cost function and system dynamics.
  2. Define a training environment that inherits from gpc.envs.base.TrainingEnv. This must implement the reset, get_obs, and observation_size methods. For example:
class MyCustomEnv(TrainingEnv):
    def __init__(self):
        super().__init__(task=MyCustomHydraxTask(), episode_length=100)

    def reset(self, data: mjx.Data, rng: jax.Array) -> mjx.Data:
        """Reset the simulator to start a new episode."""
        ...
        return new_data

    def get_obs(self, data: mjx.Data) -> jax.Array:
        """Get the observation from the simulator."""
        ...
        return jax.array([obs1, obs2, ...])

    @property
    def observation_size(self) -> int:
        """Return the size of the observation vector."""
        ...

Then you should be able to run gpc.training.train to train a flow-matching policy, and gpc.testing.test_interactive to run an interactive simulation with the trained policy. See the environments in gpc.envs for examples and additional details.

Citation

@article{kurtz2025generative,
  title={Generative Predictive Control: Flow Matching Policies for Dynamic and Difficult-to-Demonstrate Task},
  author={Kurtz, Vince and Burdick, Joel},
  journal={arXiv preprint arXiv:2502.13406},
  year={2025},
}