Generative Predictive Control
February 20, 2025
This repository contains code for the paper "Generative Predictive Control: Flow Matching Policies for Dynamic and Difficult-to-Demonstrate Tasks" by Vince Kurtz and Joel Burdick. Video summary.
This includes code for training and testing flow-matching policies on several robot systems.
Generative Predictive Control (GPC) is a supervised learning framework for training flow-matching policies on tasks that are difficult to demonstrate but easy to simulate. GPC alternates between generating training data with sampling-based predictive control, fitting a generative model to the data, and using the generative model to improve the sampling distribution.
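The alternation described above can be illustrated with a toy sketch. This is not the repository's implementation: the one-dimensional quadratic cost, the function names (cost, predictive_sampling, fit_gaussian, toy_gpc), and all parameter values are hypothetical, and a simple Gaussian fit stands in for the flow-matching model that GPC actually trains.

```python
import random
import statistics


def cost(action: float, target: float = 2.0) -> float:
    """Toy task (hypothetical): quadratic cost, minimized at action == target."""
    return (action - target) ** 2


def predictive_sampling(mean: float, std: float, num_samples: int,
                        rng: random.Random) -> float:
    """Sampling-based control: draw candidate actions and keep the cheapest."""
    candidates = [rng.gauss(mean, std) for _ in range(num_samples)]
    return min(candidates, key=cost)


def fit_gaussian(data: list) -> tuple:
    """Stand-in generative model: a Gaussian fit (GPC itself uses flow matching)."""
    # Floor the standard deviation so the sampling distribution never collapses.
    return statistics.fmean(data), max(statistics.pstdev(data), 1e-3)


def toy_gpc(num_iters: int = 10, num_rollouts: int = 32,
            num_samples: int = 16, seed: int = 0) -> tuple:
    rng = random.Random(seed)
    mean, std = 0.0, 1.0  # initial sampling distribution
    for _ in range(num_iters):
        # 1. Generate training data with sampling-based predictive control.
        data = [predictive_sampling(mean, std, num_samples, rng)
                for _ in range(num_rollouts)]
        # 2. Fit the generative model to the data.
        # 3. Use the fitted model as the next sampling distribution.
        mean, std = fit_gaussian(data)
    return mean, std


if __name__ == "__main__":
    print(toy_gpc())
```

In this sketch the sampling distribution drifts toward the low-cost action and tightens around it, which is the bootstrapping effect the three-step loop is meant to produce.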
Install (Conda)
Clone and create the conda env (first time only):
git clone https://github.com/vincekurtz/gpc.git
cd gpc
conda env create -f environment.yml
Enter the conda env:
conda activate gpc
Install the package and dependencies:
pip install -e .
Examples
Various examples can be found in the examples directory. For
example, to train a cart-pole swingup policy using GPC, run:
python examples/cart_pole.py train
This will train a flow-matching policy and save it to
/tmp/cart_pole_policy.pkl. To run an interactive simulation with the trained
policy, run:
python examples/cart_pole.py test
To see other command-line options, run:
python examples/cart_pole.py --help
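For readers new to flow matching, the sketch below shows roughly how such a policy turns noise into an action: start from a Gaussian sample and integrate a velocity field from t = 0 to t = 1. It is an illustration only and does not use the gpc API. The names velocity and sample_action are hypothetical, and the closed-form velocity field is valid only when the target distribution is a single point; a trained policy would use a learned network conditioned on the observation instead.

```python
import random


def velocity(x: float, t: float, target: float) -> float:
    """Velocity that carries x at time t to `target` at t = 1.

    Assumption: the target distribution is a single point, so the
    flow-matching velocity has this closed form. A trained policy would
    replace it with a neural network.
    """
    return (target - x) / (1.0 - t)


def sample_action(target: float, num_steps: int = 10, seed: int = 0) -> float:
    """Generate an action by Euler-integrating the flow from noise."""
    rng = random.Random(seed)
    x = rng.gauss(0.0, 1.0)  # start from Gaussian noise
    dt = 1.0 / num_steps
    for k in range(num_steps):
        t = k * dt  # t stays in [0, 1 - dt], so 1 - t is never zero
        x += dt * velocity(x, t, target)
    return x


if __name__ == "__main__":
    print(sample_action(target=0.5))
```

Because the field points straight at the target, the final Euler step lands on it up to floating-point error regardless of the initial noise sample.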
Using a Different Robot Model
To try GPC on your own robot or task, you will need to:
- Define a Hydrax task that encodes the cost function and system dynamics.
- Define a training environment that inherits from
gpc.envs.base.TrainingEnv. This must implement the reset, get_obs, and observation_size methods. For example:
import jax
import jax.numpy as jnp
from mujoco import mjx

from gpc.envs.base import TrainingEnv


class MyCustomEnv(TrainingEnv):
    def __init__(self):
        # MyCustomHydraxTask is the Hydrax task defined in the previous step.
        super().__init__(task=MyCustomHydraxTask(), episode_length=100)

    def reset(self, data: mjx.Data, rng: jax.Array) -> mjx.Data:
        """Reset the simulator to start a new episode."""
        ...
        return new_data

    def get_obs(self, data: mjx.Data) -> jax.Array:
        """Get the observation from the simulator."""
        ...
        return jnp.array([obs1, obs2, ...])

    @property
    def observation_size(self) -> int:
        """Return the size of the observation vector."""
        ...
Then you should be able to run gpc.training.train to train a flow-matching
policy, and gpc.testing.test_interactive to run an interactive simulation with
the trained policy. See the environments in gpc.envs for examples
and additional details.
Citation
@article{kurtz2025generative,
  title={Generative Predictive Control: Flow Matching Policies for Dynamic and Difficult-to-Demonstrate Tasks},
  author={Kurtz, Vince and Burdick, Joel},
  journal={arXiv preprint arXiv:2502.13406},
  year={2025},
}





