July 9, 2024

rlbase

This codebase implements simple reinforcement learning algorithms in JAX, along with support for several environments. The goal is to provide solid, single-file implementations of various RL algorithms for research use. Both online and offline methods are included.

Online Algorithms Implemented:

Offline Algorithms Implemented:

Environments Supported:

  • (Online) Gym: MuJoCo locomotion (HalfCheetah-v2, etc.) and classic control (CartPole-v1)
  • (Online) DeepMind Control: cheetah_run, pendulum_swingup, etc.
  • (Offline) D4RL MuJoCo locomotion: halfcheetah-medium-expert-v2, etc.
  • (Offline) D4RL AntMaze and goal-conditioned AntMaze: antmaze-large-diverse-v2, gc-antmaze-large-diverse-v2
  • (Offline) ExORL: exorl_cheetah_walk, etc.
  • See envs/env_helper.py for the full list.
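The environment names above follow recognizable patterns per suite: CamelCase Gym ids, lowercase D4RL dataset ids, `gc-` and `exorl_` prefixes, and `domain_task` names for DeepMind Control. The sketch below infers the suite from a name using only those patterns. It is a hypothetical helper for illustration, not the logic in envs/env_helper.py; the suite labels it returns are made up.

```python
from typing import Tuple


def infer_suite(env_name: str) -> str:
    """Guess the benchmark suite from an env name's pattern (hypothetical helper)."""
    if env_name.startswith("exorl_"):
        return "exorl"                  # e.g. exorl_cheetah_walk
    if env_name.startswith("gc-"):
        return "d4rl_goal_conditioned"  # e.g. gc-antmaze-large-diverse-v2
    if env_name.startswith("antmaze-"):
        return "d4rl_antmaze"           # e.g. antmaze-large-diverse-v2
    if "-v" in env_name:
        # Gym ids are CamelCase (HalfCheetah-v2); D4RL dataset ids are lowercase.
        return "gym" if env_name[:1].isupper() else "d4rl"
    if "_" in env_name:
        return "dmc"                    # e.g. cheetah_run, pendulum_swingup
    raise ValueError(f"unrecognized env name: {env_name!r}")


def split_dmc_name(env_name: str) -> Tuple[str, str]:
    """Split a DeepMind Control name into (domain, task) at the last underscore."""
    domain, task = env_name.rsplit("_", 1)
    return domain, task


for name in ["HalfCheetah-v2", "cheetah_run", "halfcheetah-medium-expert-v2",
             "gc-antmaze-large-diverse-v2", "exorl_cheetah_walk"]:
    print(f"{name:30s} -> {infer_suite(name)}")
```

Splitting at the last underscore keeps multi-word domains intact (ball_in_cup_catch becomes ("ball_in_cup", "catch")), which is the (domain_name, task_name) pair that dm_control's suite.load expects.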

Installation

For the cleanest installation, create a conda environment:

conda env create -f deps/environment.yml

For full reproducibility, you can also build from the Singularity definition file at deps/base_container.def.
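After creating the environment, a quick way to check that the main dependencies resolve is to look them up without importing them. The package names below are an assumption based on JAX and the suites listed above (Gym, DeepMind Control, D4RL); deps/environment.yml remains the authoritative list.

```python
import importlib.util
from typing import List


def missing_packages(names: List[str]) -> List[str]:
    """Return the top-level packages in `names` that cannot be found."""
    # find_spec locates a top-level module without importing it; None means not installed.
    return [name for name in names if importlib.util.find_spec(name) is None]


# Assumed package names; edit to match deps/environment.yml.
print("missing:", missing_packages(["jax", "gym", "dm_control", "d4rl"]) or "none")
```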

Reproduction

We provide a set of stable results comparing each algorithm against a reference implementation. For full training curves, see the wandb reports for the online results and for the offline results.

You can reproduce these results with the commands in run_baselines.py. The simplest starting point is to run an individual algorithm file directly, e.g.:

python algs_online/ppo.py --env_name walker_walk --agent.gamma 0.99
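The command above overrides nested settings with dotted flags such as --agent.gamma. This README does not say how those flags are parsed; the style resembles ml_collections config flags. As a hedged illustration of the idea only, the stdlib sketch below applies `--section.key value` pairs to a nested dict of defaults, casting each value to the type of the default it replaces. The helper name and the default values are hypothetical.

```python
import copy
from typing import Any, Dict, List


def apply_overrides(defaults: Dict[str, Any], argv: List[str]) -> Dict[str, Any]:
    """Apply `--a.b value` overrides to a nested dict (hypothetical helper)."""
    config = copy.deepcopy(defaults)  # leave the defaults untouched
    if len(argv) % 2 != 0:
        raise ValueError("expected --key value pairs")
    for flag, raw in zip(argv[0::2], argv[1::2]):
        if not flag.startswith("--"):
            raise ValueError(f"expected a flag, got {flag!r}")
        *parents, leaf = flag[2:].split(".")
        node = config
        for key in parents:
            node = node[key]  # raises KeyError for an unknown section
        if leaf not in node:
            raise KeyError(f"unknown config key: {flag[2:]}")
        default = node[leaf]
        if isinstance(default, bool):
            node[leaf] = raw.lower() in ("1", "true", "yes")
        else:
            node[leaf] = type(default)(raw)  # e.g. "0.95" -> 0.95 for a float default
    return config


# Hypothetical defaults; the real ones live in each algorithm file.
DEFAULTS = {"env_name": "cheetah_run", "agent": {"gamma": 0.99, "lr": 3e-4}}
cfg = apply_overrides(DEFAULTS, ["--env_name", "walker_walk", "--agent.gamma", "0.95"])
print(cfg)
```

Casting to the default's type catches typos early: an unknown key or a non-numeric value for a float setting fails immediately instead of silently running with the wrong configuration.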

Offline Results

Env | Best Performance (ours) | Best Original Performance (reference paper) | Source
exorl_cheetah_run | 257.5 (IQL-DDPG) | ~250 (TD3) | exorl
exorl_walker_run | 471.9 (IQL-DDPG) | ~200 (TD3) | exorl
halfcheetah-medium-expert-v2 | 83.8 (IQL) | 90.7 (TD3+BC) | iql
walker2d-medium-expert-v2 | 106.8 (BC) | 110.1 (TD3+BC) | iql
hopper-medium-expert-v2 | 98.9 (IQL) | 98.7 (CQL) | iql
gc-antmaze-large-diverse-v2 | 52.5 (IQL) | 50.7 (IQL) | hiql
gc-maze2d-large-v1 | 97.5 (IQL) | N/A | N/A
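The D4RL locomotion rows (for example halfcheetah-medium-expert-v2) appear to use the standard D4RL normalized score rather than raw returns; the README does not state this, so treat it as an assumption. Under that convention, 0 corresponds to a random policy and 100 to the D4RL expert reference, which is why entries above 100, such as 106.8 and 110.1 on walker2d-medium-expert-v2, are possible. A minimal sketch of the formula, with hypothetical reference returns:

```python
def d4rl_normalized_score(raw_return: float, random_return: float,
                          expert_return: float) -> float:
    """Map a raw episode return onto the 0-100 D4RL scale (0 = random, 100 = expert)."""
    if expert_return == random_return:
        raise ValueError("random and expert reference returns must differ")
    return 100.0 * (raw_return - random_return) / (expert_return - random_return)


# Hypothetical reference returns, for illustration only.
print(d4rl_normalized_score(5850.0, random_return=-300.0, expert_return=12000.0))
```

The d4rl package exposes the same mapping as env.get_normalized_score, which returns the fraction before scaling by 100.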

Online Results

Env | Best Performance (ours) | Best Original Performance (reference paper) | Source
HalfCheetah-v2 | 11029 (SAC) | 12138.8 (SAC) | tianshou
Walker2d-v2 | 5101.8 (SAC-Tianshoulike) | 5007 (SAC) | tianshou
Hopper-v2 | 2714.4 (REDQ) | 3542.2 (SAC) | tianshou
cheetah_run | 918.9 (REDQ) | 800 (SAC) | pytorch-sac
walker_run | 835.7 (TD3) | 700 (SAC) | pytorch-sac
hopper_hop | 474.9 (TD3) | 210 (SAC) | pytorch-sac
quadruped_run | 920.8 (TD3) | 700 (SAC) | pytorch-sac
humanoid_run | 211.8 (REDQ) | 90 (SAC) | pytorch-sac
pendulum_swingup | 790.2 (SAC) | 920 (SAC) | pytorch-sac
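One rough way to read the online table is to express each of our results as a percentage of the reference number. The sketch below only does that arithmetic on the values in the table. The reference numbers come from other codebases (tianshou, pytorch-sac), and the table does not state their training budgets or evaluation protocols, so a single ratio is a coarse comparison at best.

```python
from typing import Dict, Tuple

# (ours, reference) pairs copied from the online results table.
ONLINE_RESULTS: Dict[str, Tuple[float, float]] = {
    "HalfCheetah-v2": (11029.0, 12138.8),
    "Walker2d-v2": (5101.8, 5007.0),
    "Hopper-v2": (2714.4, 3542.2),
    "cheetah_run": (918.9, 800.0),
    "walker_run": (835.7, 700.0),
    "hopper_hop": (474.9, 210.0),
    "quadruped_run": (920.8, 700.0),
    "humanoid_run": (211.8, 90.0),
    "pendulum_swingup": (790.2, 920.0),
}


def percent_of_reference(ours: float, reference: float) -> float:
    """Our score as a percentage of the reference score (100 means parity)."""
    return 100.0 * ours / reference


for env, (ours, ref) in ONLINE_RESULTS.items():
    print(f"{env:18s} {percent_of_reference(ours, ref):6.1f}%")
```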

History

This code is based largely on the jaxrl_m repo, and also takes inspiration from jaxrl and cleanrl.